diff --git a/README.md b/README.md index 05fcb23f7edd657f2ea495d848fadc226e56b524..bbebfc723a65ef1365362cb56298a9399c85e179 100644 --- a/README.md +++ b/README.md @@ -81,13 +81,13 @@ The TensorFlow project strives to abide by generally accepted best practices in | Build Type | Status | Artifacts | | --- | --- | --- | -| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Linux XLA** | TBA | TBA | -| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | +| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg) | TBA | +| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | ### Community Supported Builds @@ -97,7 +97,8 @@ The TensorFlow project strives to abide by generally accepted best practices in | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | -| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA | +| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6| ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) | ## For more information diff --git a/RELEASE.md b/RELEASE.md index 6b67072f8ecafa08c747f8296c7c2a59eb2350fa..078aafd3746e5ce5c16af15de80d99c1a9e8c567 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,68 @@ +# Release 1.10.0 + +## Major Features And Improvements + +* The `tf.lite` runtime now supports `complex64`. +* Initial Bigtable integration for `tf.data`. +* Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation. +* `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`. +* Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018. +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. See below for the complete list. New symbols have been added to the following modules: [`tf.debugging`](https://www.tensorflow.org/versions/master/api_docs/python/tf/debugging), [`tf.dtypes`](https://www.tensorflow.org/versions/master/api_docs/python/tf/dtypes), [`tf.image`](https://www.tensorflow.org/versions/master/api_docs/python/tf/image), [`tf.io`](https://www.tensorflow.org/versions/master/api_docs/python/tf/io), [`tf.linalg`](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), [`tf.manip`](https://www.tensorflow.org/versions/master/api_docs/python/tf/manip), [`tf.math`](https://www.tensorflow.org/versions/master/api_docs/python/tf/math), [`tf.quantization`](https://www.tensorflow.org/versions/master/api_docs/python/tf/quantization), [`tf.strings`](https://www.tensorflow.org/versions/master/api_docs/python/tf/strings) + +## Breaking Changes + +* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites). +* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake. + +## Bug Fixes and Other Changes + +* `tf.data`: + * `tf.contrib.data.group_by_reducer()` is now available via the public API. + * `tf.contrib.data.choose_from_datasets()` is now available via the public API. + * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`. +* `tf.estimator`: + * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export. + * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases. + * Support sparse_combiner in canned Linear Estimators. + * Added batch normalization to `DNNClassifier`, `DNNRegressor`, and `DNNEstimator`. + * Adding ranking support for boosted trees. + * Adding center bias option for boosted trees. +* Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables. +* Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables. +* `tf.losses.*` do not add to the global collection when executing eagerly (to avoid leaking memory). +* Support different summary and checkpoint directories in `tf.train.MonitoredTrainingSession()`. +* Added IndRNN, IndyGRU, and IndyLSTM cells to `tf.contrib.rnn`. +* Add safe static factory functions for SparseTensor and convert all CHECKs to DCHECKs. Using the constructor directly is unsafe and deprecated. +* Make the Bigtable client connection pool configurable & increase the default # of connections for performance. +* Added derivative of `tf.random_gamma` with respect to the alpha parameter. +* Added derivative of `tf.igamma(a, x)` and `tf.igammac(a, x)` with respect to a. +* Modified Bessel functions of order zero and one. +* Add FillTriangular Bijector to create triangular matrices. +* Added support for Type III DCT, and `tf.spectral.idct(type=2|3)`. +* Correctly handle CuDNN RNN weight loaded when nest in `TimeDistributed`. +* Adding per-element weight support for `WALSComputePartialLhsAndRhsOp`. +* ZerosLike and OnesLike ops treated as constants by Graph Transform Tool. +* Gamma distribution and the derived distributions (Beta, Dirichlet, Student's t, inverse Gamma) now fully reparameterized. +* Java: Experimental wrapper classes to make graph generation easier. Thanks @karllessard and @kbsriram +* Build & link in secure gRPC components (switch from the insecure grpc dependency to secure grpc dependency). +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. List of new endpoints: + * New endpoints in `tf.image` namespace: `tf.image.extract_image_patches` + * New endpoints in `tf.debugging` namespace: `tf.debugging.check_numerics`, `tf.debugging.is_finite`, `tf.debugging.is_inf`, `tf.debugging.is_nan`. + * New endpoints in `tf.dtypes` namespace: `tf.dtypes.as_string`. + * New endpoints in `tf.io` namespace: `tf.io.decode_base64`, `tf.io.decode_compressed`, `tf.io.decode_json_example`, `tf.io.decode_raw`, `tf.io.encode_base64`, `tf.io.matching_files`, `tf.io.parse_tensor`, `tf.io.read_file, `tf.io.write_file`. + * New endpoints in tf.linalg namespace: `tf.linalg.cross`, `tf.linalg.tensor_diag` (corresponds to `tf.diag`), `tf.linalg.tensor_diag_part` (corresponds to `tf.diag_part`). + * New endpoints in tf.manip namespace: `tf.manip.batch_to_space_nd`, `tf.manip.gather_nd`, `tf.manip.reshape`, `tf.manip.reverse`, `tf.manip.scatter_nd`, `tf.manip.space_to_batch_nd`, `tf.manip.tile` + * New endpoints in tf.math namespace: `tf.math.acos`, `tf.math.acosh`, `tf.math.add`, `tf.math.asin`, `tf.math.asinh`, `tf.math.atan`, `tf.math.atan2`, `tf.math.atanh`, `tf.math.betainc`, `tf.math.ceil`, `tf.math.cos`, `tf.math.cosh`, `tf.math.digamma`, `tf.math.equal`, `tf.math.erfc`, `tf.math.exp`, `tf.math.expm1`, `tf.math.floor`, `tf.math.greater`, `tf.math.greater_equal`, `tf.math.igamma`, `tf.math.igammac`, `tf.math.invert_permutation`, `tf.math.less`, `tf.math.less_equal`, `tf.math.lgamma`, `tf.math.log`, `tf.math.log1p`, `tf.math.logical_and`, `tf.math.logical_not`, `tf.math.logical_or`, `tf.math.maximum`, `tf.math.minimum`, `tf.math.not_equal`, `tf.math.polygamma`, `tf.math.reciprocal`, `tf.math.rint`, `tf.math.rsqrt`, `tf.math.segment_max`, `tf.math.segment_mean`, `tf.math.segment_min`, `tf.math.segment_prod`, `tf.math.segment_sum`, `tf.math.sin`, `tf.math.sinh`, `tf.math.softplus`, `tf.math.softsign`, `tf.math.squared_difference`, `tf.math.tan`, `tf.math.unsorted_segment_max`, `tf.math.unsorted_segment_min`, `tf.math.unsorted_segment_prod`, `tf.math.unsorted_segment_sum`, `tf.math.zeta`. + * New endpoints in `tf.quantization` namespace: `tf.quantization.dequantize`, `tf.quantization.fake_quant_with_min_max_args`, `tf.quantization.fake_quant_with_min_max_args_gradient`, `tf.quantization.fake_quant_with_min_max_vars`, `tf.quantization.fake_quant_with_min_max_vars_gradient`, `tf.quantization.fake_quant_with_min_max_vars_per_channel`, `tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient`. + * New endpoints in tf.strings namespace: `tf.strings.join` (corresponds to `tf.string_join`), `tf.strings.regex_replace`, `tf.strings.to_number` (corresponds to `tf.string_to_number`), `tf.strings.strip` (corresponds to `tf.string_strip`), `tf.strings.substr`, `tf.strings.to_hash_bucket` (corresponds to `tf.string_to_hash_bucket`), `tf.strings.to_hash_bucket_fast` (corresponds to `tf.string_to_hash_bucket_fast`), `tf.strings.to_hash_bucket_strong` (corresponds to `tf.string_to_hash_bucket_strong`). + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, Andrei Nigmatulin, Andrew Ginns, BjøRn Moholt, Brett Koonce, Chengzhi Chen, Chinmay Das, Christian Ertler, Christoph Boeddeker, Clayne Robison, Courtial Florian, ctiijima, Dan Douthit, Dan J, Dan Ringwalt, EFanZh, Emanuele Ballarin, eqy, Evgeniy Zheltonozhskiy, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, G K, gracehoney, Guillaume Klein, Guozhong Zhuang, Hsien-Yang Li, hsm207, ImSheridan, Jayaram Bobba, Jiandong Ruan, Jie, Joel Shor, Jonas Rauber, Jongmin Baek, jsawruk, Karan Kaw, Karl Lessard, karl@kubx.ca, Kb Sriram, KinmanLam, leiiwang, Li, Yiqiang, Loo Rong Jie, Mahmoud Abuzaina, Mahmoud Aslan, ManHyuk, Martin Patz, Martin Zeitler, mktozk, Mohammad Ashraf Bhuiyan, mrTsjolder, Naman Bhalla, Nick Felt, Nicolas Lopez, Niranjan Hasabnis, Nishidha Panpaliya, Nitish, nrstott, Nutti, Parag Jain, PeterLee, Philipp Jund, Rach L, Rafal Wojdyla, Roland Zimmermann, Sergei Lebedev, SneakyFish5, Soila Kavulya, Sriram Veturi, Steven Schmatz, Taehoon Lee, Tang, Wenyi, Taras Sereda, Ted Chang, Tim Zaman, Tristan Rice, tucan, vchigrin, Vikram Tiwari, Vincent, WeberXie, William D. Irons, Yan Facai (颜发才), Yong Tang, Yu Yi, Yuxin Wu, Zé ViníCius + # Release 1.9.0 ## Major Features And Improvements diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 388ca3f293ebfa120037b75fe70c66b9d715c051..f8cd6820244aa05724ce0980419eb7b77962ff91 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -381,6 +381,14 @@ config_setting( }, ) +# Setting to use when loading kernels dynamically +config_setting( + name = "dynamic_loaded_kernels", + define_values = { + "dynamic_loaded_kernels": "true", + }, +) + config_setting( name = "using_cuda_nvcc", define_values = { @@ -408,14 +416,6 @@ config_setting( visibility = ["//visibility:public"], ) -# TODO(laigd): consider removing this option and make TensorRT enabled -# automatically when CUDA is enabled. -config_setting( - name = "with_tensorrt_support", - values = {"define": "with_tensorrt_support=true"}, - visibility = ["//visibility:public"], -) - package_group( name = "internal", packages = [ @@ -441,11 +441,6 @@ filegroup( ), ) -filegroup( - name = "docs_src", - data = glob(["docs_src/**/*.md"]), -) - cc_library( name = "grpc", deps = select({ @@ -589,6 +584,7 @@ exports_files( gen_api_init_files( name = "tensorflow_python_api_gen", srcs = ["api_template.__init__.py"], + api_version = 1, root_init_template = "api_template.__init__.py", ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 10bc8cdbee5a9df6d2084c10adab4ed6e5e6f0d3..19ccb6e71d2f3021c1ce5c8905d8a72059c1cfcb 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -2389,6 +2390,12 @@ void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy) { + TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy); +} + +void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, + int ny, TF_Output* x, int nx, TF_Output* dx, + TF_Status* status, TF_Output* dy) { #ifdef __ANDROID__ status->status = tensorflow::errors::Unimplemented( "Adding gradients is not supported in Android. File a bug at " @@ -2405,9 +2412,29 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, const int first_new_node_id = g->graph.num_node_ids(); + string prefix_cmp; + const char* child_scope_name; + if (prefix == nullptr) { + child_scope_name = "gradients"; + } else { + prefix_cmp = string(prefix) + "/"; + // The operation should fail if the provided name prefix has already been + // used in this graph + for (const auto& pair : g->name_map) { + const string& name = pair.first; + if (name.compare(prefix) == 0 || + tensorflow::str_util::StartsWith(name, prefix_cmp)) { + status->status = InvalidArgument( + "prefix [", prefix, + "] conflicts with existing node in the graph named [", name, "]"); + return; + } + } + child_scope_name = prefix; + } tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) - .NewSubScope("gradients"); + .NewSubScope(child_scope_name); if (dx != nullptr) { std::vector dx_arg = OutputsFromTFOutputs(dx, ny); @@ -2422,6 +2449,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; + + // Adding the gradients to the graph can alter the prefix to prevent + // name collisions only if this prefix has not been provided explicitly + // by the user. If it was provided, assert that it remained intact. + if (prefix != nullptr && + !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) { + status->status = tensorflow::errors::Internal( + "BUG: The gradients prefix have been unexpectedly altered when " + "adding the nodes to the graph. This is a bug. Please file an " + "issue at https://github.com/tensorflow/tensorflow/issues."); + return; + } // We have a convoluted scheme here: Using the C++ graph construction API // to add potentially many nodes to the graph without running the checks // (such as uniqueness of the names of nodes) we run with other functions diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c8ae6f2dd1780c4fe50ff1924be8d2e9a7502cf0..850f6ecd637d768bca99720e0add07680829e17a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1131,6 +1131,7 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, // i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// // `dx` are used as initial gradients (which represent the symbolic partial // derivatives of some loss function `L` w.r.t. `y`). // `dx` must be nullptr or have size `ny`. @@ -1139,6 +1140,12 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // The partial derivatives are returned in `dy`. `dy` should be allocated to // size `nx`. // +// Gradient nodes are automatically named under the "gradients/" prefix. To +// guarantee name uniqueness, subsequent calls to the same graph will +// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ... +// See TF_AddGradientsWithPrefix, which provides a means to specify a custom +// name prefix for operations added to a graph to compute the gradients. +// // WARNING: This function does not yet support all the gradients that python // supports. See // https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md @@ -1147,6 +1154,33 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy); +// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// This is a variant of TF_AddGradients that allows to caller to pass a custom +// name prefix to the operations added to a graph to compute the gradients. +// +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all +// shapes in `y`. +// The partial derivatives are returned in `dy`. `dy` should be allocated to +// size `nx`. +// `prefix` names the scope into which all gradients operations are being added. +// `prefix` must be unique within the provided graph otherwise this operation +// will fail. If `prefix` is nullptr, the default prefixing behaviour takes +// place, see TF_AddGradients for more details. +// +// WARNING: This function does not yet support all the gradients that python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add C++ more gradients. +TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, + TF_Output* y, int ny, + TF_Output* x, int nx, + TF_Output* dx, TF_Status* status, + TF_Output* dy); + // Create a TF_Function from a TF_Graph // // Params: @@ -1236,6 +1270,11 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Returns the name of the graph function. +// The return value points to memory that is only usable until the next +// mutation to *func. +TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); + // Write out a serialized representation of `func` (as a FunctionDef protocol // message) to `output_func_def` (allocated by TF_NewBuffer()). // `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 170046c8024dc85c899108b254cd3a95a3be4096..69b3ffe2a1f620e346405607ecf742fb863aa644 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -84,6 +84,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, return ret; } +TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) { + tensorflow::RunOptions options; + if (enable_full_trace) { + options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); + } else { + options.set_trace_level(tensorflow::RunOptions::NO_TRACE); + } + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(options, ret)); + return ret; +} + const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { tensorflow::mutex_lock c(graph->mu); const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 2d81c01e0dd056e9beb3b45f24809381554a7924..6617c5a572e90e78369f73d714f39942f213040f 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -70,6 +70,12 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth); +// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level +// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE +// otherwise. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( + unsigned char enable_full_trace); + // Returns the graph content in a human-readable format, with length set in // `len`. The format is subject to change in the future. // The returned string is heap-allocated, and caller should call free() on it. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 384e6c8cb97022264c5327da5ca5861057608fbe..a2c5a42c11361779de61b515e0f08dcc45e609b9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -536,6 +536,10 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } +const char* TF_FunctionName(TF_Function* func) { + return func->fdef.signature().name().c_str(); +} + void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, const TF_Function* grad, TF_Status* status) { if (func == nullptr) { diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index f7ca219c896b2a7c07fc4d0739c70f2666652672..73fe73769bc1219ce865149d67d333c53371ccc5 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -193,6 +193,7 @@ class CApiFunctionTest : public ::testing::Test { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); ASSERT_NE(func_, nullptr); + ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_))); TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); } @@ -1618,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } +// This test only works when the TF build includes XLA compiler. One way to set +// this up is via bazel build option "--define with_xla_support=true". +// +// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to +// something like TENSORFLOW_CAPI_USE_XLA. +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST_F(CApiFunctionTest, StatelessIf_XLA) { + TF_Function* func; + const std::string funcName = "BranchFunc"; + DefineFunction(funcName.c_str(), &func); + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* feed = Placeholder(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); + TF_AddInput(desc, {true_cond, 0}); + TF_Output inputs[] = {{feed, 0}}; + TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_SetAttrType(desc, "Tcond", TF_BOOL); + TF_DataType inputType = TF_INT32; + TF_SetAttrTypeList(desc, "Tin", &inputType, 1); + TF_SetAttrTypeList(desc, "Tout", &inputType, 1); + TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); + TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); + TF_SetDevice(desc, "/device:XLA_CPU:0"); + auto op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(op, nullptr); + + // Create a session for this graph. + CSession csession(host_graph_, s_, /*use_XLA*/ true); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor(17)}}); + csession.SetOutputs({op}); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(-17, *output_contents); + + // Clean up + csession.CloseAndDelete(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(func); +} +#endif // TENSORFLOW_EAGER_USE_XLA + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index e674b1623cf540eb8024d9be5ed8d77aa2fe17ba..aa2a537f03be31ae45ff3d6f7815b449d661cf9c 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1483,8 +1483,8 @@ class CApiGradientsTest : public ::testing::Test { BuildSuccessGraph(inputs, outputs); BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); - AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs); - + AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1, + grad_outputs); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); // Compare that the graphs match. @@ -1505,7 +1505,8 @@ class CApiGradientsTest : public ::testing::Test { BuildErrorGraph(inputs, outputs); - AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs); + AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1, + grad_outputs); string expected_msg = "No gradient defined for op: TestOpWithNoGradient. Please see " @@ -1549,19 +1550,20 @@ class CApiGradientsTest : public ::testing::Test { EXPECT_EQ(*a_data, *b_data); } - void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs, - TF_Output* outputs, int noutputs, TF_Output* grad_outputs) { + void AddGradients(bool grad_inputs_provided, const char* prefix, + TF_Output* inputs, int ninputs, TF_Output* outputs, + int noutputs, TF_Output* grad_outputs) { if (grad_inputs_provided) { TF_Output grad_inputs[1]; const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0}; TF_Operation* grad_inputs_op = FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs"); grad_inputs[0] = TF_Output{grad_inputs_op, 0}; - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs, - s_, grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, grad_inputs, s_, grad_outputs); } else { - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_, - grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, nullptr, s_, grad_outputs); } } @@ -1706,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test { return op; } + void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1, + const char* prefix2 = nullptr) { + TF_Output inputs[2]; + TF_Output outputs[1]; + TF_Output grad_outputs[2]; + + BuildSuccessGraph(inputs, outputs); + + AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs); + if (prefix2 != nullptr) { + AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs); + } + } + TF_Status* s_; TF_Graph* graph_; TF_Graph* expected_graph_; @@ -1725,6 +1741,56 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) { + BuildGraphAndAddGradientsWithPrefixes("gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) { + BuildGraphAndAddGradientsWithPrefixes("Const_0"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) { + BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + void ScalarFloatFromTensor(const TF_Tensor* t, float* f) { ASSERT_TRUE(t != nullptr); ASSERT_EQ(TF_FLOAT, TF_TensorType(t)); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 24eb6c069b21349fce288db3e79fbf14e824ad11..f15d9ee20adb31a0b76e2cd0d1e67f17a9deff05 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 38313d647ca93d4779bb1325f8ed7bde4b743879..7eeb1ee5e17ad7e5644f8bc8a18ca967b108475d 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -31,6 +31,8 @@ using ::tensorflow::string; typedef std::unique_ptr unique_tensor_ptr; +TF_Tensor* BoolTensor(int32_t v); + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); @@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 6c510536d6f2a586b91baf96fa41b779db2c8d35..dfb1c9a37644c726e1eabab775593596d5b556b9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices( tensorflow::Status CreateRemoteContexts( const std::vector& remote_workers, int64 rendezvous_id, - const tensorflow::ServerDef& server_def, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -150,8 +151,9 @@ tensorflow::Status CreateRemoteContexts( return tensorflow::Status::OK(); } -tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, - TFE_Context** ctx) { +tensorflow::Status UpdateTFE_ContextWithServerDef( + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -165,12 +167,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, } \ } while (0); - string worker_name = tensorflow::strings::StrCat( - "/job:", opts->server_def.job_name(), - "/replica:0/task:", opts->server_def.task_index()); + string worker_name = + tensorflow::strings::StrCat("/job:", server_def.job_name(), + "/replica:0/task:", server_def.task_index()); std::unique_ptr server; - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); tensorflow::GrpcServer* grpc_server = dynamic_cast(server.get()); @@ -202,15 +204,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, opts->server_def, - remote_eager_workers.get(), opts->async, &remote_contexts)); + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, opts->server_def, true)); + session_name, server_def, true)); std::shared_ptr worker_session; TF_RETURN_IF_ERROR( @@ -221,10 +223,11 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); auto* device_mgr = grpc_server->worker_env()->device_mgr; - *ctx = new TFE_Context(opts->session_options.options, opts->policy, - opts->async, device_mgr, r, std::move(server), - std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts); + + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -249,15 +252,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status) { - if (!options->server_def.ParseFromArray(proto, proto_len)) { - status->status = tensorflow::errors::InvalidArgument( - "Invalid tensorflow.ServerDef protocol buffer"); - } -} - TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -267,12 +261,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - if (!opts->server_def.job_name().empty()) { - TFE_Context* ctx = nullptr; - status->status = NewRemoteAwareTFE_Context(opts, &ctx); - return ctx; - } - std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", @@ -288,7 +276,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->async, std::move(device_mgr), r); } -void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } +void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; @@ -301,6 +289,22 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +// Set server_def on the context, possibly updating it. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); +} + void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { ctx->context.SetThreadLocalDevicePlacementPolicy( @@ -336,7 +340,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { - DCHECK(h); + if (h == nullptr) return; if (h->handle) { h->handle->Unref(); } @@ -348,6 +352,11 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } int result; status->status = h->handle->NumDims(&result); return result; @@ -355,12 +364,22 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } tensorflow::int64 result; status->status = h->handle->Dim(dim_index, &result); return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } tensorflow::Device* d = nullptr; status->status = h->handle->OpDevice(&d); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" @@ -368,6 +387,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } // TODO(agarwal): move this implementation inside TFE_TensorHandle. tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; @@ -700,6 +724,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } + +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } + namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index fdbd5374b2afe815c3a81b453930eb8f1fa351d3..a0ebc6fa0a22ed61be91c2974352c2988fb4cd92 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); -// A tensorflow.ServerDef specifies remote workers (in addition to the current -// workers name). Operations created on this context can then be executed on -// any of these remote workers by setting an appropriate device. -// -// If the following is set, all servers identified by the -// ServerDef must be up when the context is created. -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status); - // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -102,8 +92,7 @@ typedef struct TFE_Context TFE_Context; TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( const TFE_ContextOptions* opts, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, - TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); @@ -128,6 +117,18 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, unsigned char async, TF_Status* status); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + // Causes the calling thread to block till all ops dispatched in async mode // have been executed. Note that "execution" here refers to kernel execution / // scheduling of copies, etc. Similar to sync execution, it doesn't guarantee @@ -380,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Some TF ops need a step container to be set to limit the lifetime of some +// resources (mostly TensorArray and Stack, used in while loop gradients in +// graph mode). Calling this on a context tells it to start a step. +TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); + +// Ends a step. When there is no active step (that is, every started step has +// been ended) step containers will be cleared. Note: it is not safe to call +// TFE_ContextEndStep while ops which rely on the step container may be running. +TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 4c5077023d5bb3b83808bf3908e7110dd026e3ad..a5c0681e2e4eddae08954d9d0178ca96a3f8f29a 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -59,7 +59,6 @@ struct TFE_ContextOptions { // true if async execution is enabled. bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; - tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -73,23 +72,6 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} - explicit TFE_Context( - const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, bool async, - tensorflow::DeviceMgr* local_device_mgr, - tensorflow::Rendezvous* rendezvous, - std::unique_ptr server, - std::unique_ptr remote_eager_workers, - std::unique_ptr remote_device_mgr, - const tensorflow::gtl::FlatMap& - remote_contexts) - : context(opts, - static_cast( - default_policy), - async, local_device_mgr, rendezvous, std::move(server), - std::move(remote_eager_workers), std::move(remote_device_mgr), - remote_contexts) {} - tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3504a8b5e78480732d3454097c1b2197ac2b2e17..71d5f3613c89762633113b4e1dfb82b8199a1cd1 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -49,7 +49,7 @@ void BM_InitOp(int iters) { } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -80,7 +80,7 @@ void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -95,7 +95,7 @@ TEST(CAPI, Context) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const int num_devices = TF_DeviceListCount(devices); @@ -108,14 +108,14 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } -tensorflow::ServerDef GetServerDef(int num_tasks) { +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { tensorflow::ServerDef server_def; server_def.set_protocol("grpc"); - server_def.set_job_name("localhost"); + server_def.set_job_name(job_name); server_def.set_task_index(0); tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name("localhost"); + job_def->set_name(job_name); for (int i = 0; i < num_tasks; i++) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( @@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) { return server_def; } +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} + void TestRemoteExecute(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); @@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); @@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char remote_device_name[] = @@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; @@ -281,7 +285,7 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, + const std::vector& expected_values) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + std::unique_ptr actual_values(new float[expected_values.size()]); + EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); + memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + + for (int i = 0; i < expected_values.size(); i++) { + EXPECT_EQ(expected_values[i], actual_values[i]) + << "Mismatch in expected values at (zero-based) index " << i; + } +} + +void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, + const char* remote_device_name, + const char* local_device_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = + TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); + + TFE_DeleteTensorHandle(retval_task0); + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +void TestRemoteExecuteChangeServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); + + // Update the server def with a new set of names (worker instead of + // localhost). + tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); + serialized = updated_server_def.SerializeAsString(); + + updated_server_def.set_task_index(1); + tensorflow::Status s = tensorflow::GrpcServer::Create( + updated_server_def, tensorflow::Env::Default(), &worker_server); + ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(worker_server->Start().ok()); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Create a new tensor_handle. + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(); + + // Check that copying it to the old remote device (named localhost) fails. + TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Copying and executing on the new remote device works. + const char new_remote_device_name[] = + "/job:worker/replica:0/task:1/device:CPU:0"; + const char new_local_device_name[] = + "/job:worker/replica:0/task:0/device:CPU:0"; + + auto* h0_task1_new = TFE_TensorHandleCopyToDevice( + h0_task0_new, ctx, new_remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0_new); + TFE_DeleteTensorHandle(h0_task1_new); + + CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, + new_local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + TFE_DeleteContext(ctx); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteChangeServerDef) { + TestRemoteExecuteChangeServerDef(false); +} +TEST(CAPI, RemoteExecuteChangeServerDefAsync) { + TestRemoteExecuteChangeServerDef(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -380,8 +525,7 @@ void TensorHandleCopyBetweenDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevices) { @@ -418,7 +562,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) { TFE_DeleteTensorHandle(hcopy); TFE_DeleteTensorHandle(hcpu); if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevicesError) { @@ -451,7 +595,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); return; } const string gpu_1_name(TF_DeviceListName(devices, 1, status.get())); @@ -484,8 +628,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { @@ -533,8 +676,7 @@ void TensorHandleSilentCopy(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } @@ -580,8 +722,7 @@ void TensorHandleSilentCopyLocal(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } TEST(CAPI, TensorHandleSilentCopyLocalAsync) { @@ -614,11 +755,47 @@ void SetAndGetOpDevices(bool async) { TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TF_Tensor* t = TFE_TensorHandleResolve(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(t, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + const char* device_name = TFE_TensorHandleDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int num_dims = TFE_TensorHandleNumDims(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(num_dims, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int dim = TFE_TensorHandleDim(h, 0, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(dim, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -640,7 +817,7 @@ void Execute_MatMul_CPU(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -712,7 +889,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteTensorHandle(m1); TFE_DeleteTensorHandle(m2); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { @@ -743,7 +920,7 @@ void Execute_MatMul_CPU_Type_Error(bool async) { if (retvals[0] != nullptr) { TFE_DeleteTensorHandle(retvals[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } @@ -781,7 +958,7 @@ TEST(CAPI, Execute_Min_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -823,7 +1000,7 @@ void Execute_MatMul_XLA_CPU(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } @@ -862,7 +1039,7 @@ void Execute_Min_XLA_CPU(bool async) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } @@ -898,7 +1075,7 @@ void ExecuteWithTracing(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -974,7 +1151,7 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1044,7 +1221,7 @@ TEST(CAPI, Function_ident_XLA_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1120,7 +1297,7 @@ void FunctionDefAndExecute(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1161,7 +1338,7 @@ void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1249,7 +1426,7 @@ TEST(CAPI, Variables) { TFE_DeleteTensorHandle(var_handle); TFE_DeleteTensorHandle(value_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1288,7 +1465,7 @@ void BM_ReadVariable(int iters) { TFE_DeleteOp(op); TFE_DeleteTensorHandle(var_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a98f0b00b2c70055f697ed4f15cb14708384b62f..588a45ea43f90c4d9b3d04fea305d2c562ae1d72 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -121,6 +121,7 @@ cc_library( deps = [ ":array_grad", ":data_flow_grad", + ":image_grad", ":math_grad", ":nn_grad", ], @@ -331,6 +332,36 @@ tf_cc_test( ], ) +cc_library( + name = "image_grad", + srcs = ["gradients/image_grad.cc"], + deps = [ + ":cc_ops", + ":cc_ops_internal", + ":grad_op_registry", + ":gradients", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "gradients_image_grad_test", + srcs = ["gradients/image_grad_test.cc"], + deps = [ + ":cc_ops", + ":client_session", + ":grad_op_registry", + ":grad_testutil", + ":gradient_checker", + ":image_grad", + ":testutil", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "math_grad", srcs = ["gradients/math_grad.cc"], diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index ba056a8f3a84910aebf5079573cb64c19f41469d..0e61089a5950ee894ad5489317757cff8a85e966 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, target_node_names, outputs, run_metadata); } +Status ClientSession::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); + return impl()->session_->MakeCallable(callable_options, out_handle); +} + +Status ClientSession::RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { + return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata); +} + +Status ClientSession::ReleaseCallable(CallableHandle handle) { + return impl()->session_->ReleaseCallable(handle); +} + } // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 5fb4109f7d15d5997f745acd913e60a02855fd73..7dd653eec4ec729b652cb779d06e820bfb437b3c 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -87,7 +87,33 @@ class ClientSession { const std::vector& run_outputs, std::vector* outputs, RunMetadata* run_metadata) const; - // TODO(keveman): Add support for partial run. + /// \brief A handle to a subgraph, created with + /// `ClientSession::MakeCallable()`. + typedef int64 CallableHandle; + + /// \brief Creates a `handle` for invoking the subgraph defined by + /// `callable_options`. + /// NOTE: This API is still experimental and may change. + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle); + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must match the order of names in + /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will + /// match the order of names in `CallableOptions::fetch()` when this subgraph + /// was created. + /// NOTE: This API is still experimental and may change. + Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata); + + /// \brief Releases resources associated with the given `handle` in this + /// session. + /// NOTE: This API is still experimental and may change. + Status ReleaseCallable(CallableHandle handle); private: class Impl; diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index ea5cf5a1f12be316cc6e0d0a02cd3caf4d177400..559ffea7e817526e7f1396cd0e8187d01364f23b 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2})); } +TEST(ClientSessionTest, Callable) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector outputs; + + CallableOptions options; + options.add_feed(a.node()->name()); + options.add_feed(b.node()->name()); + options.add_fetch(c.node()->name()); + ClientSession::CallableHandle callable; + TF_CHECK_OK(session.MakeCallable(options, &callable)); + TF_EXPECT_OK(session.RunCallable( + callable, {test::AsTensor({1}, {}), test::AsTensor({41}, {})}, + &outputs, nullptr)); + test::ExpectTensorEqual(outputs[0], test::AsTensor({42}, {})); + TF_EXPECT_OK(session.ReleaseCallable(callable)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index de2645cb440bda1f35e764af9197ca97bb760c08..e9f9c59e3aa0e8a9dc5d5e658540e9da73adaca5 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, auto y_pos_flat = y_pos[y_idx].flat(); auto y_neg_flat = y_neg[y_idx].flat(); const int64 y_size = y_shapes[y_idx].num_elements(); - const Y_T scale = Y_T{2 * delta}; + const Y_T scale = 2 * delta; auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); for (int c = 0; c < y_size; ++c) { SetJacobian(&jacobian, r * x_stride + unit_dimension, @@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, auto jac_n = jacobian_ns[i].matrix(); for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) { for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) { - *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c))); + auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c)); + // Treat any NaN as max_error and immediately return. + // (Note that std::max may ignore NaN arguments.) + if (std::isnan(cur_error)) { + *max_error = cur_error; + return Status::OK(); + } + *max_error = std::max(*max_error, cur_error); } } } @@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x, const Output& y, const TensorShape& y_shape, JAC_T* max_error); INSTANTIATE_GRAD_ERR_TYPE(float, float, float); +INSTANTIATE_GRAD_ERR_TYPE(double, float, double); INSTANTIATE_GRAD_ERR_TYPE(double, double, double); INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float); INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float); diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index d4f0a7f5ab3716be41e22c02a21aca028f76fb88..8dd762c282eff287bddd49ea6f38b2b8060949b0 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -28,12 +28,14 @@ namespace { using ops::Complex; using ops::Const; +using ops::Div; using ops::MatMul; using ops::Placeholder; using ops::Real; using ops::Split; using ops::Square; using ops::Stack; +using ops::Sub; using ops::Unstack; TEST(GradientCheckerTest, BasicFloat) { @@ -104,6 +106,20 @@ TEST(GradientCheckerTest, Complex64ToFloat) { EXPECT_LT(max_error, 1e-4); } +// When calculating gradients that are undefined, test we get NaN +// as the computed error rather than 0. +TEST(GradientCheckerTest, BasicNan) { + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape)); + // y = x/(x-x) should always return NaN + auto y = Div(scope, x, Sub(scope, x, x)); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_TRUE(std::isnan(max_error)); +} + TEST(GradientCheckerTest, MatMulGrad) { Scope scope = Scope::NewRootScope(); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index b353accddcb6db9a07c112de03ead2f02c4ee6a6..e9173227aadbf86eab666e6c17bacacb92888572 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Split", SplitGrad); +Status FillGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = fill(fill_shape, x) + // No gradient returned for the fill_shape argument. + grad_outputs->push_back(NoGradient()); + // The gradient for x (which must be a scalar) is just the sum of + // all the gradients from the shape it fills. + // We use ReduceSum to implement this, which needs an argument providing + // the indices of all the dimensions of the incoming gradient. + // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) + auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), + Const(scope, 1)); + grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); + return scope.status(); +} +REGISTER_GRADIENT_OP("Fill", FillGrad); + Status DiagGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index d09275b6487b4212aa35a0476002f2bb587fa210..f41de3dc2098df55fbbb616557f264a4e70db6b6 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) { RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); } +TEST_F(ArrayGradTest, FillGrad) { + TensorShape x_shape({}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + TensorShape y_shape({2, 5, 3}); + auto y = Fill(scope_, {2, 5, 3}, x); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(ArrayGradTest, DiagGrad) { TensorShape x_shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc new file mode 100644 index 0000000000000000000000000000000000000000..882709e1e2817431a32c453fe0f35f2b2e6c69b0 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad.cc @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/ops/image_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace ops { +namespace { + +Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + // The internal gradient implementation needs the shape of the input image. + // x_shape = shape(x)[1:3] + // = slice(shape(x), {1}, {3 - 1}) + auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2}); + grad_outputs->push_back(internal::ResizeNearestNeighborGrad( + scope, grad_inputs[0], x_shape, + internal::ResizeNearestNeighborGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeNearestNeighbor", ResizeNearestNeighborGradHelper); + +Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBilinearGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBilinearGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBilinear", ResizeBilinearGradHelper); + +Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBicubicGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBicubicGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper); + +} // anonymous namespace +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e55c7561b030c50bd67bd53fd0d55710085c5d2 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/testutil.h" +#include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using ops::Const; +using ops::ResizeBicubic; +using ops::ResizeBilinear; +using ops::ResizeNearestNeighbor; + +class ImageGradTest : public ::testing::Test { + protected: + ImageGradTest() : scope_(Scope::NewRootScope()) {} + + enum OpType { RESIZE_NEAREST, RESIZE_BILINEAR, RESIZE_BICUBIC }; + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const OpType op_type, const Tensor& x_data, const Input& y_shape, + const bool align_corners, Output* x, Output* y) { + *x = Const(scope_, x_data); + switch (op_type) { + case RESIZE_NEAREST: + *y = ResizeNearestNeighbor( + scope_, *x, y_shape, + ResizeNearestNeighbor::AlignCorners(align_corners)); + return; + case RESIZE_BILINEAR: + *y = ResizeBilinear(scope_, *x, y_shape, + ResizeBilinear::AlignCorners(align_corners)); + return; + case RESIZE_BICUBIC: + *y = ResizeBicubic(scope_, *x, y_shape, + ResizeBicubic::AlignCorners(align_corners)); + return; + } + assert(false); + } + + template + void TestResizedShapeForType(const OpType op_type, const bool align_corners) { + TensorShape x_shape({1, 2, 2, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + + ClientSession session(scope_); + std::vector outputs; + TF_ASSERT_OK(session.Run({y}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(outputs[0].shape(), TensorShape({1, 4, 6, 1})); + } + + void TestResizedShape(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + } + } + + template + void TestResizeToSmallerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 4, 6, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {2, 3}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 2, 3, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResizeToLargerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 2, 3, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResize(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizeToSmallerAndAlign(op_type, align_corners); + TestResizeToLargerAndAlign(op_type, align_corners); + } + } + + Scope scope_; +}; + +TEST_F(ImageGradTest, TestNearestNeighbor) { + TestResizedShape(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); +} + +TEST_F(ImageGradTest, TestBilinear) { + TestResizedShape(RESIZE_BILINEAR); + TestResize(RESIZE_BILINEAR); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BILINEAR); +} + +TEST_F(ImageGradTest, TestBicubic) { + TestResizedShape(RESIZE_BICUBIC); + TestResize(RESIZE_BICUBIC); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BICUBIC); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index fd7b6fe6625f27bda92e2f56f60908658cdecd7e..1c9bdff5e1295135abe60c282d565c39071fd78a 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -475,11 +475,7 @@ TEST_F(CWiseUnaryGradTest, Tan_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); }; - // TODO(kbsriram) - // Enable when tan kernel supports complex inputs - if (false) { - TestCWiseGrad(TAN, x_fn); - } + TestCWiseGrad(TAN, x_fn); } TEST_F(CWiseUnaryGradTest, Atan) { diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 07807ed2f3e10cb03ec363562087a40c9e86fc5e..3830416159158cca8bfb8422c2959b49fa42406d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -74,6 +74,54 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir, } } +// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid +// leaving behind non-GC'ed state. +// +// Detailed motivation behind this approach, from ashankar@: +// +// Each call to Session::Run() that identifies a new subgraph (based on feeds +// and fetches) creates some datastructures that live as long as the session +// (the partitioned graph, associated executors etc.). +// +// A pathological case of this would be if say the initialization op +// (main_op/legacy_init_op) involves the use of a large constant. Then we +// allocate memory for that large constant that will just stick around till the +// session dies. With this Callable mechanism, that memory will be released +// right after ReleaseCallable returns. +// +// However, the resource manager state remains. +Status RunOnce(const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, RunMetadata* run_metadata, + Session* session) { + CallableOptions callable_options; + std::vector feed_tensors; + *callable_options.mutable_run_options() = run_options; + for (const auto& input : inputs) { + const string& name = input.first; + const Tensor& tensor = input.second; + callable_options.add_feed(name); + feed_tensors.push_back(tensor); + } + for (const string& output_tensor_name : output_tensor_names) { + callable_options.add_fetch(output_tensor_name); + } + for (const string& target_node_name : target_node_names) { + callable_options.add_target(target_node_name); + } + + Session::CallableHandle callable_handle; + TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle)); + const Status run_status = session->RunCallable(callable_handle, feed_tensors, + outputs, run_metadata); + // Be sure to call ReleaseCallable() regardless of the outcome of + // RunCallable(). + session->ReleaseCallable(callable_handle).IgnoreError(); + return run_status; +} + bool HasMainOp(const MetaGraphDef& meta_graph_def) { const auto& collection_def_map = meta_graph_def.collection_def(); if (collection_def_map.find(kSavedModelMainOpKey) != @@ -86,10 +134,11 @@ bool HasMainOp(const MetaGraphDef& meta_graph_def) { Status RunMainOp(const RunOptions& run_options, const string& export_dir, const MetaGraphDef& meta_graph_def, const std::vector& asset_file_defs, - Session* session) { - LOG(INFO) << "Running MainOp on SavedModel bundle."; + Session* session, const string& main_op_key) { + LOG(INFO) << "Running MainOp with key " << main_op_key + << " on SavedModel bundle."; const auto& collection_def_map = meta_graph_def.collection_def(); - const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey); + const auto main_op_it = collection_def_map.find(main_op_key); if (main_op_it != collection_def_map.end()) { if (main_op_it->second.node_list().value_size() != 1) { return errors::FailedPrecondition( @@ -99,8 +148,8 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); + return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } return Status::OK(); } @@ -121,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored."; + "were restored. File does not exist: " + << variables_index_path; return Status::OK(); } const string variables_path = @@ -137,32 +187,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return session->Run(run_options, inputs, {}, {restore_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); -} - -Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir, - const MetaGraphDef& meta_graph_def, - const std::vector& asset_file_defs, - Session* session) { - LOG(INFO) << "Running LegacyInitOp on SavedModel bundle."; - const auto& collection_def_map = meta_graph_def.collection_def(); - const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey); - if (init_op_it != collection_def_map.end()) { - if (init_op_it->second.node_list().value_size() != 1) { - return errors::FailedPrecondition(strings::StrCat( - "Expected exactly one serving init op in : ", export_dir)); - } - std::vector> inputs; - AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); - RunMetadata run_metadata; - const StringPiece legacy_init_op_name = - init_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, - {legacy_init_op_name.ToString()}, nullptr /* outputs */, - &run_metadata); - } - return Status::OK(); + return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, @@ -204,11 +230,11 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, if (HasMainOp(bundle->meta_graph_def)) { TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + bundle->session.get(), kSavedModelMainOpKey)); } else { - TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + TF_RETURN_IF_ERROR(RunMainOp( + run_options, export_dir, bundle->meta_graph_def, asset_file_defs, + bundle->session.get(), kSavedModelLegacyInitOpKey)); } return Status::OK(); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 2119c8ec47f941a76e81346ae5d20da78eae11a3..d2f803bd18b38ad5c1a8b5afd70531db117826ea 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -8,28 +8,6 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -# Optional runtime utilities for use by code generated by tfcompile. -cc_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - ], -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], - deps = [ - ":runtime", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - # Don't depend on this directly; this is only used for the benchmark test # generated by tf_library. cc_library( @@ -53,9 +31,9 @@ cc_library( ], deps = [ ":embedded_protocol_buffers", - ":runtime", # needed by codegen to print aligned_buffer_bytes "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -68,6 +46,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu_internal", @@ -237,7 +216,6 @@ test_suite( tests = [ ":benchmark_test", ":codegen_test", - ":runtime_test", ":test_graph_tfadd_test", ":test_graph_tfunknownop2_test", ":test_graph_tfunknownop3_test", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 28070d60dbbe6dd8f930b8e6509cedcf09f94e11..8dbe1e11b7c392cca29fc8792d3cf9f1bf44f1fb 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -303,10 +303,10 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, const std::vector iarg(arg_sizes.begin(), arg_sizes.end()); const std::vector itemp(temp_sizes.begin(), temp_sizes.end()); const size_t arg_bytes_aligned = - runtime::aligned_buffer_bytes(iarg.data(), iarg.size()); + cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size()); const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size()); const size_t temp_bytes_aligned = - runtime::aligned_buffer_bytes(itemp.data(), itemp.size()); + cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size()); const size_t temp_bytes_total = total_buffer_bytes(itemp.data(), itemp.size()); diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index bbc35da2ef6d14ff0d3570ef2d5cf6743456c674..2b5f97b34cd928d32eb220536342c715d91d45bb 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 5c57fee326ca743dcb8aaae354d261ed4d7f44be..326f73b975aec3a7a6bc7cdc9a92f540ad545ad6 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -16,339 +16,365 @@ tf_library( ) """ -load("//tensorflow:tensorflow.bzl", - "if_android", "tf_cc_test", "tf_copts") - -def tf_library(name, graph, config, - freeze_checkpoint=None, freeze_saver=None, - cpp_class=None, gen_test=True, gen_benchmark=True, - visibility=None, testonly=None, - tfcompile_flags=None, - tfcompile_tool="//tensorflow/compiler/aot:tfcompile", - include_standard_runtime_deps=True, - enable_xla_hlo_profiling=False, deps=None, tags=None): - """Runs tfcompile to compile a TensorFlow graph into executable code. - - Given an invocation of tf_library(name="foo", ...), generates the following - build targets: - foo: A cc_library containing the generated header and computation. - foo_test: A cc_test with simple tests and benchmarks. Only created if - gen_test=True. - foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful - for mobile devices or other platforms that can't compile the - full test libraries. Only created if gen_benchmark=True. - - Args: - name: The name of the build rule. - graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it - is expected to be in the human-readable proto text format, otherwise it is - expected to be in the proto binary format. - config: File containing tensorflow.tf2xla.Config proto. If the file ends - in '.pbtxt' it is expected to be in the human-readable proto text format, - otherwise it is expected to be in the proto binary format. - freeze_checkpoint: If provided, run freeze_graph with this checkpoint to - convert variables into constants. - freeze_saver: If provided, run freeze_graph with this saver, in SaverDef - binary form, to convert variables into constants. - cpp_class: The name of the generated C++ class, wrapping the generated - function. The syntax of this flag is - [[::],...]. This mirrors the C++ syntax - for referring to a class, where multiple namespaces may precede the class - name, separated by double-colons. The class will be generated in the - given namespace(s), or if no namespaces are given, within the global - namespace. - gen_test: If True, also generate a cc_test rule that builds a simple - test and benchmark. - gen_benchmark: If True, also generate a binary with a simple benchmark. - Unlike the output of gen_test, this benchmark can be run on android. - visibility: Bazel build visibility. - testonly: Bazel testonly attribute. - tfcompile_flags: Extra flags to pass to tfcompile to control compilation. - tfcompile_tool: The tfcompile binary. A non-default can be passed to - use a tfcompile built with extra dependencies. - include_standard_runtime_deps: If True, the standard list of kernel/runtime - deps is added to deps. If False, deps must contain the full set of deps - needed by the generated library. - enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, - and emit metadata that lets us pretty-print the gathered profile counters. - deps: a list of deps to include on the build rules for the generated - library, added to the standard deps if standard_runtime_deps is True. - tags: tags to apply to subsidiary build rules. - - The output header is called .h. - """ - if not cpp_class: - fail("cpp_class must be specified") - - tfcompile_graph = graph - if freeze_checkpoint or freeze_saver: - if not freeze_checkpoint: - fail("freeze_checkpoint must be specified when freeze_saver is specified") +load( + "//tensorflow:tensorflow.bzl", + "if_android", + "tf_cc_test", + "tf_copts", +) - freeze_name = "freeze_" + name - freeze_file = freeze_name + ".pb" +def tf_library( + name, + graph, + config, + freeze_checkpoint = None, + freeze_saver = None, + cpp_class = None, + gen_test = True, + gen_benchmark = True, + visibility = None, + testonly = None, + tfcompile_flags = None, + tfcompile_tool = "//tensorflow/compiler/aot:tfcompile", + include_standard_runtime_deps = True, + enable_xla_hlo_profiling = False, + deps = None, + tags = None): + """Runs tfcompile to compile a TensorFlow graph into executable code. - # First run tfcompile to generate the list of out_nodes. - out_nodes_file = "out_nodes_" + freeze_name - native.genrule( - name=("gen_" + out_nodes_file), - srcs=[config], - outs=[out_nodes_file], - cmd=("$(location " + tfcompile_tool + ")" + - " --config=$(location " + config + ")" + - " --dump_fetch_nodes > $@"), - tools=[tfcompile_tool], - # Run tfcompile on the build host, rather than forge, since it's - # typically way faster on the local machine. - local=1, - tags=tags, - ) + Given an invocation of tf_library(name="foo", ...), generates the following + build targets: + foo: A cc_library containing the generated header and + computation. + foo_test: A cc_test with simple tests and benchmarks. Only created if + gen_test=True. + foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, + useful for mobile devices or other platforms that can't + compile the full test libraries. Only created if + gen_benchmark=True. + The output header is called .h. - # Now run freeze_graph to convert variables into constants. - freeze_args = (" --input_graph=$(location " + graph + ")" + - " --checkpoint_version=1" + - " --input_binary=" + str(not graph.endswith(".pbtxt")) + - " --input_checkpoint=$(location " + freeze_checkpoint + ")" + - " --output_graph=$(location " + freeze_file + ")" + - " --output_node_names=$$(<$(location " + out_nodes_file + - "))") - freeze_saver_srcs = [] - if freeze_saver: - freeze_args += " --input_saver=$(location " + freeze_saver + ")" - freeze_saver_srcs += [freeze_saver] - native.genrule( - name=freeze_name, - srcs=[ - graph, - freeze_checkpoint, - out_nodes_file, - ] + freeze_saver_srcs, - outs=[freeze_file], - cmd=("$(location //tensorflow/python/tools:freeze_graph)" + - freeze_args), - tools=["//tensorflow/python/tools:freeze_graph"], - tags=tags, - ) - tfcompile_graph = freeze_file + Args: + name: The name of the build rule. + graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' + it is expected to be in the human-readable proto text format, otherwise + it is expected to be in the proto binary format. + config: File containing tensorflow.tf2xla.Config proto. If the file ends + in '.pbtxt' it is expected to be in the human-readable proto text + format, otherwise it is expected to be in the proto binary format. + freeze_checkpoint: If provided, run freeze_graph with this checkpoint to + convert variables into constants. + freeze_saver: If provided, run freeze_graph with this saver, in SaverDef + binary form, to convert variables into constants. + cpp_class: The name of the generated C++ class, wrapping the generated + function. The syntax of this flag is + [[::],...]. This mirrors the C++ syntax + for referring to a class, where multiple namespaces may precede the + class name, separated by double-colons. The class will be generated in + the given namespace(s), or if no namespaces are given, within the global + namespace. + gen_test: If True, also generate a cc_test rule that builds a simple + test and benchmark. + gen_benchmark: If True, also generate a binary with a simple benchmark. + Unlike the output of gen_test, this benchmark can be run on android. + visibility: Bazel build visibility. + testonly: Bazel testonly attribute. + tfcompile_flags: Extra flags to pass to tfcompile to control compilation. + tfcompile_tool: The tfcompile binary. A non-default can be passed to + use a tfcompile built with extra dependencies. + include_standard_runtime_deps: If True, the standard list of + kernel/runtime deps is added to deps. If False, deps must contain the + full set of deps needed by the generated library. + enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated + program, and emit metadata that lets us pretty-print the gathered + profile counters. + deps: a list of deps to include on the build rules for the generated + library, added to the standard deps if standard_runtime_deps is True. + tags: tags to apply to subsidiary build rules. + """ + if not cpp_class: + fail("cpp_class must be specified") - # Rule that runs tfcompile to produce the header and object file. - header_file = name + ".h" - metadata_object_file = name + "_tfcompile_metadata.o" - function_object_file = name + "_tfcompile_function.o" - ep = ("__" + native.package_name() + "__" + name).replace("/", "_") - if type(tfcompile_flags) == type(""): - flags = tfcompile_flags - else: - flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) - if enable_xla_hlo_profiling: - profiling_flag = "--xla_hlo_profile" - else: - profiling_flag = "" - native.genrule( - name=("gen_" + name), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - header_file, - metadata_object_file, - function_object_file, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + - " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - # Run tfcompile on the build host since it's typically faster on the local - # machine. - # - # Note that setting the local=1 attribute on a *test target* causes the - # test infrastructure to skip that test. However this is a genrule, not a - # test target, and runs with --genrule_strategy=forced_forge, meaning the - # local=1 attribute is ignored, and the genrule is still run. - # - # https://www.bazel.io/versions/master/docs/be/general.html#genrule - local=1, - tags=tags, - ) + tfcompile_graph = graph + if freeze_checkpoint or freeze_saver: + if not freeze_checkpoint: + fail("freeze_checkpoint must be specified when freeze_saver is " + + "specified") - # Rule that runs tfcompile to produce the SessionModule proto, useful for - # debugging. TODO(b/64813587): Once the SessionModule proto is - # deterministic, move this into the main rule above. - session_module_pb = name + "_session_module.pb" - native.genrule( - name=(name + "_session_module"), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - session_module_pb, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - local=1, - tags=tags, - ) + freeze_name = "freeze_" + name + freeze_file = freeze_name + ".pb" - # The cc_library rule packaging up the header and object file, and needed - # kernel implementations. - need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) - native.cc_library( - name=name, - srcs=[function_object_file, metadata_object_file], - hdrs=[header_file], - visibility=visibility, - testonly=testonly, - deps = [ - # These deps are required by all tf_library targets even if - # include_standard_runtime_deps is False. Without them, the - # generated code will fail to compile. - "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", - "//tensorflow/core:framework_lite", - ] + (need_xla_data_proto and [ - # If we're generating the program shape, we must depend on the proto. - "//tensorflow/compiler/xla:xla_data_proto", - ] or []) + (enable_xla_hlo_profiling and [ - "//tensorflow/compiler/xla/service:hlo_profile_printer_data" - ] or []) + (include_standard_runtime_deps and [ - # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", - "//third_party/eigen3", - ] or []) + (deps or []), - tags=tags, - ) + # First run tfcompile to generate the list of out_nodes. + out_nodes_file = "out_nodes_" + freeze_name + native.genrule( + name = ("gen_" + out_nodes_file), + srcs = [config], + outs = [out_nodes_file], + cmd = ("$(location " + tfcompile_tool + ")" + + " --config=$(location " + config + ")" + + " --dump_fetch_nodes > $@"), + tools = [tfcompile_tool], + # Run tfcompile on the build host, rather than forge, since it's + # typically way faster on the local machine. + local = 1, + tags = tags, + ) - # Variables used for gen_test and gen_benchmark. - no_ns_name = "" - cpp_class_split = cpp_class.rsplit("::", maxsplit=2) - if len(cpp_class_split) == 1: - no_ns_name = cpp_class_split[0] - else: - no_ns_name = cpp_class_split[1] - sed_replace = ( - "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + - "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + - "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ") + # Now run freeze_graph to convert variables into constants. + freeze_args = ( + " --input_graph=$(location " + graph + ")" + + " --checkpoint_version=1" + + " --input_binary=" + str(not graph.endswith(".pbtxt")) + + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + + " --output_graph=$(location " + freeze_file + ")" + + " --output_node_names=$$(<$(location " + out_nodes_file + + "))" + ) + freeze_saver_srcs = [] + if freeze_saver: + freeze_args += " --input_saver=$(location " + freeze_saver + ")" + freeze_saver_srcs += [freeze_saver] + native.genrule( + name = freeze_name, + srcs = [ + graph, + freeze_checkpoint, + out_nodes_file, + ] + freeze_saver_srcs, + outs = [freeze_file], + cmd = ("$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args), + tools = ["//tensorflow/python/tools:freeze_graph"], + tags = tags, + ) + tfcompile_graph = freeze_file - if gen_test: - test_name = name + "_test" - test_file = test_name + ".cc" - # Rule to rewrite test.cc to produce the test_file. + # Rule that runs tfcompile to produce the header and object file. + header_file = name + ".h" + metadata_object_file = name + "_tfcompile_metadata.o" + function_object_file = name + "_tfcompile_function.o" + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") + if type(tfcompile_flags) == type(""): + flags = tfcompile_flags + else: + flags = " ".join([ + "'" + arg.replace("'", "'\\''") + "'" + for arg in (tfcompile_flags or []) + ]) + if enable_xla_hlo_profiling: + profiling_flag = "--xla_hlo_profile" + else: + profiling_flag = "" native.genrule( - name=("gen_" + test_name), - testonly=1, - srcs=[ - "//tensorflow/compiler/aot:test.cc", + name = ("gen_" + name), + srcs = [ + tfcompile_graph, + config, + ], + outs = [ header_file, + metadata_object_file, + function_object_file, ], - outs=[test_file], - cmd=("sed " + sed_replace + - " $(location //tensorflow/compiler/aot:test.cc) " + - "> $(OUTS)"), - tags=tags, - ) - - # The cc_test rule for the generated code. To ensure that this works - # reliably across build configurations, we must use tf_cc_test instead of - # native.cc_test. This is related to how we build - # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD - # for more details. - tf_cc_test( - name=test_name, - srcs=[test_file], - deps=[ - ":" + name, - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/aot:tf_library_test_main", - "//tensorflow/compiler/xla:executable_run_options", - "//third_party/eigen3", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], - tags=tags, + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag), + tools = [tfcompile_tool], + visibility = visibility, + testonly = testonly, + # Run tfcompile on the build host since it's typically faster on the + # local machine. + # + # Note that setting the local=1 attribute on a *test target* causes the + # test infrastructure to skip that test. However this is a genrule, not + # a test target, and runs with --genrule_strategy=forced_forge, meaning + # the local=1 attribute is ignored, and the genrule is still run. + # + # https://www.bazel.io/versions/master/docs/be/general.html#genrule + local = 1, + tags = tags, ) - if gen_benchmark: - benchmark_name = name + "_benchmark" - benchmark_file = benchmark_name + ".cc" - benchmark_main = ("//tensorflow/compiler/aot:" + - "benchmark_main.template") - - # Rule to rewrite benchmark.cc to produce the benchmark_file. + # Rule that runs tfcompile to produce the SessionModule proto, useful for + # debugging. TODO(b/64813587): Once the SessionModule proto is + # deterministic, move this into the main rule above. + session_module_pb = name + "_session_module.pb" native.genrule( - name=("gen_" + benchmark_name), - srcs=[ - benchmark_main, - header_file, + name = (name + "_session_module"), + srcs = [ + tfcompile_graph, + config, ], + outs = [ + session_module_pb, + ], + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags), + tools = [tfcompile_tool], + visibility = visibility, testonly = testonly, - outs=[benchmark_file], - cmd=("sed " + sed_replace + - " $(location " + benchmark_main + ") " + - "> $(OUTS)"), - tags=tags, + local = 1, + tags = tags, ) - # The cc_benchmark rule for the generated code. This does not need the - # tf_cc_binary since we (by deliberate design) do not depend on - # //tensorflow/core:lib. - # - # Note: to get smaller size on android for comparison, compile with: - # --copt=-fvisibility=hidden - # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN - # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN - native.cc_binary( - name=benchmark_name, - srcs=[benchmark_file], + # The cc_library rule packaging up the header and object file, and needed + # kernel implementations. + need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) + native.cc_library( + name = name, + srcs = [function_object_file, metadata_object_file], + hdrs = [header_file], + visibility = visibility, testonly = testonly, - copts = tf_copts(), - linkopts = if_android(["-pie", "-s"]), - deps=[ - ":" + name, - "//tensorflow/compiler/aot:benchmark", - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/xla:executable_run_options", + deps = [ + # These deps are required by all tf_library targets even if + # include_standard_runtime_deps is False. Without them, the + # generated code will fail to compile. + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/core:framework_lite", + ] + (need_xla_data_proto and [ + # If we're generating the program shape, we must depend on the + # proto. + "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (enable_xla_hlo_profiling and [ + "//tensorflow/compiler/xla/service:hlo_profile_printer_data", + ] or []) + (include_standard_runtime_deps and [ + # TODO(cwhipkey): only depend on kernel code that the model actually + # needed. + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", - ] + if_android([ - "//tensorflow/compiler/aot:benchmark_extra_android", - ]), - tags=tags, + ] or []) + (deps or []), + tags = tags, + ) + + # Variables used for gen_test and gen_benchmark. + cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + if len(cpp_class_split) == 1: + no_ns_name = cpp_class_split[0] + else: + no_ns_name = cpp_class_split[1] + sed_replace = ( + "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + + "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + + "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" " ) + if gen_test: + test_name = name + "_test" + test_file = test_name + ".cc" + + # Rule to rewrite test.cc to produce the test_file. + native.genrule( + name = ("gen_" + test_name), + testonly = 1, + srcs = [ + "//tensorflow/compiler/aot:test.cc", + header_file, + ], + outs = [test_file], + cmd = ( + "sed " + sed_replace + + " $(location //tensorflow/compiler/aot:test.cc) " + + "> $(OUTS)" + ), + tags = tags, + ) + + # The cc_test rule for the generated code. To ensure that this works + # reliably across build configurations, we must use tf_cc_test instead + # of native.cc_test. This is related to how we build + # //tensorflow/core:lib -- see the note in + # tensorflow/core/BUILD for more details. + tf_cc_test( + name = test_name, + srcs = [test_file], + deps = [ + ":" + name, + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], + tags = tags, + ) + + if gen_benchmark: + benchmark_name = name + "_benchmark" + benchmark_file = benchmark_name + ".cc" + benchmark_main = ("//tensorflow/compiler/aot:" + + "benchmark_main.template") + + # Rule to rewrite benchmark.cc to produce the benchmark_file. + native.genrule( + name = ("gen_" + benchmark_name), + srcs = [ + benchmark_main, + header_file, + ], + testonly = testonly, + outs = [benchmark_file], + cmd = ("sed " + sed_replace + + " $(location " + benchmark_main + ") " + + "> $(OUTS)"), + tags = tags, + ) + + # The cc_benchmark rule for the generated code. This does not need the + # tf_cc_binary since we (by deliberate design) do not depend on + # //tensorflow/core:lib. + # + # Note: to get smaller size on android for comparison, compile with: + # --copt=-fvisibility=hidden + # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN + # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN + native.cc_binary( + name = benchmark_name, + srcs = [benchmark_file], + testonly = testonly, + copts = tf_copts(), + linkopts = if_android(["-pie", "-s"]), + deps = [ + ":" + name, + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + ] + if_android([ + "//tensorflow/compiler/aot:benchmark_extra_android", + ]), + tags = tags, + ) + def target_llvm_triple(): - """Returns the target LLVM triple to be used for compiling the target.""" - # TODO(toddw): Add target_triple for other targets. For details see: - # http://llvm.org/docs/doxygen/html/Triple_8h_source.html - return select({ - "//tensorflow:android_armeabi": "armv5-none-android", - "//tensorflow:android_arm": "armv7-none-android", - "//tensorflow:android_arm64": "aarch64-none-android", - "//tensorflow:android_x86": "i686-none-android", - "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "//tensorflow:darwin": "x86_64-none-darwin", - "//conditions:default": "x86_64-pc-linux", - }) + """Returns the target LLVM triple to be used for compiling the target.""" + + # TODO(toddw): Add target_triple for other targets. For details see: + # http://llvm.org/docs/doxygen/html/Triple_8h_source.html + return select({ + "//tensorflow:android_armeabi": "armv5-none-android", + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", + "//conditions:default": "x86_64-pc-linux", + }) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 9174a67cc6d110ac21c7bb09346bb1b2dfad0579..15f9ba217f2c2762de36a1e1c0fc7227449bb730 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -166,6 +166,7 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -177,6 +178,7 @@ cc_library( "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:fifo_queue", + "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", @@ -185,6 +187,9 @@ cc_library( "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels/data:generator_dataset_op", + "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:prefetch_dataset_op", ], ) @@ -305,6 +310,7 @@ cc_library( srcs = [ "build_xla_launch_ops_pass.cc", "deadness_analysis.cc", + "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", "mark_for_compilation_pass.cc", ], @@ -377,10 +383,38 @@ tf_cc_test( ) tf_cc_test( - name = "compilation_passes_test", + name = "deadness_analysis_test", size = "small", srcs = [ + "deadness_analysis_internal.h", "deadness_analysis_test.cc", + ], + deps = [ + ":common", + ":compilation_passes", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "compilation_passes_test", + size = "small", + srcs = [ "encapsulate_subgraphs_pass_test.cc", "mark_for_compilation_pass_test.cc", ], diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index d81e5fe9008975c126bcd8e0ea7cef19f1eb1bf3..8aff87e5e620fefd30eeb902209c9bc17540f468 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -151,7 +152,11 @@ class SymbolPredicate : public Predicate { tensor_id_(std::move(tensor_id)), must_be_true_(must_be_true) {} - string ToString() const override { return tensor_id_.ToString(); } + string ToString() const override { + return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + : tensor_id_.ToString(); + } + Kind kind() const override { return Kind::kSymbol; } // If `must_be_true()` is true this SymbolPredicate represents the proposition @@ -348,6 +353,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status Populate(); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; + gtl::FlatMap PredicateMapAsString() const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; @@ -563,4 +569,24 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return Status::OK(); } +gtl::FlatMap +DeadnessAnalysisImpl::PredicateMapAsString() const { + gtl::FlatMap result; + std::vector tensor_ids; + for (const auto& kv_pair : predicate_map_) { + CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); + } + return result; +} + +namespace deadness_analysis_internal { +Status ComputePredicates(const Graph& graph, + PredicateMapTy* out_predicate_map) { + DeadnessAnalysisImpl impl(&graph); + TF_RETURN_IF_ERROR(impl.Populate()); + *out_predicate_map = impl.PredicateMapAsString(); + return Status::OK(); +} +} // namespace deadness_analysis_internal + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..cdef4051108fdc5d063ab592676c7644989155bf --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ + +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +namespace deadness_analysis_internal { + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. +using PredicateMapTy = gtl::FlatMap; +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); +} // namespace deadness_analysis_internal +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 584385cab7665dce9c7c92eab6293436ca22c9b7..6881095b51758d2e0b06c60021bc8c2860ac566e 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -439,5 +440,28 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) { EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node())); } +TEST(DeadnessAnalysisTest, RecvVsSwitchText) { + // Demonstrates why we need the must_be_true bit on SymbolP. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender", + 0, "receiver"); + Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL); + ops::Switch sw(root.WithOpName("switch"), value, recv); + Output logical_and = + ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + deadness_analysis_internal::PredicateMapTy predicate_map; + TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(), + &predicate_map)); + + TensorId logical_and_output_0 = {logical_and.node()->name(), + Graph::kControlSlot}; + EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index c5d0e4f8fb61b90eb58d9df398d680b3c5481196..b313d48011b561eaab618692df49d1558c34a77c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -153,6 +153,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 38eb6d830f4d4e889810acd0f928e93d0b22bde8..45d422943c23f59823e6bfbcb355d4b58a6a225e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -462,6 +462,7 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; std::unique_ptr deadness; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 54a41a4daa790401c797277e7eaab531dd34ac80..7140d47a9421ec73d0144e855b490f89569e6ae9 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -209,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable( argument_layouts[i] = &result.xla_input_shapes[i]; } xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(client_->default_device_ordinal()); + build_options.set_device_ordinal(options.device_ordinal != -1 + ? options.device_ordinal + : client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); build_options.set_device_allocator(options.device_allocator); @@ -256,6 +258,7 @@ Status XlaCompilationCache::CompileImpl( xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options, bool compile_single_op) { + CHECK_NE(executable, nullptr); VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -293,7 +296,7 @@ Status XlaCompilationCache::CompileImpl( // protect the contents of the cache entry. Entry* entry; { - mutex_lock lock(mu_); + mutex_lock lock(compile_cache_mu_); // Find or create a cache entry. std::unique_ptr& e = cache_[signature]; if (!e) { @@ -309,6 +312,8 @@ Status XlaCompilationCache::CompileImpl( if (!entry->compiled) { VLOG(1) << "Compilation cache miss for signature: " << SignatureDebugString(signature); + tensorflow::Env* env = tensorflow::Env::Default(); + const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) std::vector args; @@ -327,18 +332,35 @@ Status XlaCompilationCache::CompileImpl( compile_options ? *compile_options : XlaCompiler::CompileOptions(), function, args, &entry->compilation_result); } - } - *compilation_result = &entry->compilation_result; - if (entry->compilation_status.ok() && executable) { - if (entry->executable == nullptr) { - entry->compilation_status = BuildExecutable( - options, entry->compilation_result, &entry->executable); + TF_RETURN_IF_ERROR(entry->compilation_status); + CHECK_EQ(entry->executable.get(), nullptr); + entry->compilation_status = + BuildExecutable(options, entry->compilation_result, &entry->executable); + + const uint64 compile_end_us = env->NowMicros(); + const uint64 compile_time_us = compile_end_us - compile_start_us; + { + mutex_lock lock(compile_stats_mu_); + auto it = compile_stats_.emplace(function.name(), CompileStats{}).first; + it->second.compile_count++; + it->second.cumulative_compile_time_us += compile_time_us; + VLOG(1) << "compiled " << function.name() << " " + << it->second.compile_count + << " times, compile time: " << compile_time_us + << " us, cumulative: " << it->second.cumulative_compile_time_us + << " us (" + << tensorflow::strings::HumanReadableElapsedTime(compile_time_us / + 1.0e6) + << " / " + << tensorflow::strings::HumanReadableElapsedTime( + it->second.cumulative_compile_time_us / 1.0e6) + << ")"; } - *executable = entry->executable.get(); } - - Status status = entry->compilation_status; - return status; + TF_RETURN_IF_ERROR(entry->compilation_status); + *compilation_result = &entry->compilation_result; + *executable = entry->executable.get(); + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index be1043d8c3fc0573922837e541615114a6d7a1a5..fc5f008f4f52c32d97e680784082d0e7bcb7d8eb 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -150,9 +151,22 @@ class XlaCompilationCache : public ResourceBase { std::unique_ptr executable GUARDED_BY(mu); }; - mutex mu_; - std::unordered_map, Signature::Hash> cache_ - GUARDED_BY(mu_); + mutex compile_cache_mu_; + gtl::FlatMap, Signature::Hash> cache_ + GUARDED_BY(compile_cache_mu_); + + struct CompileStats { + // Number of times the cluster has been (re-)compiled. + int64 compile_count = 0; + + // Cumulative time spent compiling the cluster. + int64 cumulative_compile_time_us = 0; + }; + mutex compile_stats_mu_; + + // Maps cluster names to compilation statistics for said cluster. + gtl::FlatMap compile_stats_ + GUARDED_BY(compile_stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); }; diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index c55eba2f79ddcf10931ea659a64df559cef06ec5..4ddeaebd3e42e96d46857a278451d8c97e49a725 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -211,17 +211,18 @@ XlaDevice::XlaDevice( use_multiple_streams), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), - xla_allocator_(nullptr), platform_(platform), use_multiple_streams_(use_multiple_streams), transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { - VLOG(1) << "Created XLA device " << jit_device_name; + VLOG(1) << "Created XLA device " << jit_device_name << " " << this; } XlaDevice::~XlaDevice() { - if (gpu_device_info_ != nullptr) { - gpu_device_info_->default_context->Unref(); + VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; + mutex_lock lock(mu_); + if (device_context_) { + device_context_->Unref(); } } @@ -237,6 +238,11 @@ xla::LocalClient* XlaDevice::client() const { } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { + mutex_lock lock(mu_); + return GetAllocatorLocked(attr); +} + +Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) { if (attr.on_host()) { return cpu_allocator(); } @@ -249,83 +255,105 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { return xla_allocator_; } -xla::StatusOr XlaDevice::GetStream() { - if (!stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_)); - } - return stream_.get(); +Status XlaDevice::EnsureDeviceContextOk() { + mutex_lock lock(mu_); + return GetDeviceContextLocked().status(); } -xla::StatusOr XlaDevice::GetDeviceToHostStream() { - if (!use_multiple_streams_) { - return GetStream(); - } - if (!device_to_host_stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(device_to_host_stream_, - backend->BorrowStream(device_ordinal_)); +Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, + const string& name, + xla::StreamPool::Ptr* stream, + bool* stream_was_changed) { + if (!(*stream) || !(*stream)->ok()) { + TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_)); + VLOG(1) << "XlaDevice " << this << " new " << name << " " + << (*stream)->DebugStreamPointers(); + *stream_was_changed = true; } - return device_to_host_stream_.get(); + return Status::OK(); } -xla::StatusOr XlaDevice::GetHostToDeviceStream() { - if (!use_multiple_streams_) { - return GetStream(); +xla::StatusOr XlaDevice::GetDeviceContextLocked() { + xla::Backend* backend = client()->mutable_backend(); + + // Ensure all our streams are valid, borrowing new streams if necessary. + bool need_new_device_context = !device_context_; + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, + &need_new_device_context)); + + se::Stream* host_to_device_stream = stream_.get(); + se::Stream* device_to_host_stream = stream_.get(); + if (use_multiple_streams_) { + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", + &host_to_device_stream_, + &need_new_device_context)); + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", + &device_to_host_stream_, + &need_new_device_context)); + host_to_device_stream = host_to_device_stream_.get(); + device_to_host_stream = device_to_host_stream_.get(); } - if (!host_to_device_stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(host_to_device_stream_, - backend->BorrowStream(device_ordinal_)); + + if (!need_new_device_context) { + return device_context_; } - return host_to_device_stream_.get(); -} -Status XlaDevice::CreateAndSetGpuDeviceInfo() { - if (gpu_device_info_ == nullptr) { - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - // Call GetAllocator for the side-effect of ensuring the allocator - // is created. - GetAllocator({}); - // XlaDevice owns both gpu_device_info_ and - // gpu_device_info_->default_context. - gpu_device_info_ = MakeUnique(); - gpu_device_info_->stream = stream; - gpu_device_info_->default_context = - new XlaDeviceContext(stream, stream, stream, client(), - transfer_as_literal_, shape_representation_fn_); - set_tensorflow_gpu_device_info(gpu_device_info_.get()); + // At this point we know we need a new device context. + // Call GetAllocator for the side-effect of ensuring the allocator is created. + GetAllocatorLocked({}); + if (device_context_) { + device_context_->Unref(); + } + device_context_ = new XlaDeviceContext( + stream_.get(), host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_); + VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " + << device_context_; + + // Create and set a new GpuDeviceInfo, if necessary. + // + // TODO(b/78232898): This isn't thread-safe; there is a race between the call + // to set_tensorflow_gpu_device_info() with ops that call the getter + // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking + // to those methods; see the bug for details. Our only saving grace at the + // moment is that this race doesn't seem to occur in practice. + if (use_gpu_device_info_) { + auto gpu_device_info = MakeUnique(); + gpu_device_info->stream = stream_.get(); + gpu_device_info->default_context = device_context_; + set_tensorflow_gpu_device_info(gpu_device_info.get()); + gpu_device_info_ = std::move(gpu_device_info); + VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo " + << gpu_device_info_.get(); } - return Status::OK(); + return device_context_; +} + +Status XlaDevice::UseGpuDeviceInfo() { + mutex_lock lock(mu_); + use_gpu_device_info_ = true; + return GetDeviceContextLocked().status(); } Status XlaDevice::FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) { VLOG(1) << "XlaDevice::FillContextMap"; - device_context_map->resize(graph->num_node_ids()); - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream, - GetDeviceToHostStream()); - TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream, - GetHostToDeviceStream()); + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); - // Call GetAllocator for the side-effect of ensuring the allocator is created. - GetAllocator({}); - auto ctx = new XlaDeviceContext( - stream, host_to_device_stream, device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_); + device_context_map->resize(graph->num_node_ids()); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); - ctx->Ref(); - (*device_context_map)[n->id()] = ctx; + device_context->Ref(); + (*device_context_map)[n->id()] = device_context; } - ctx->Unref(); return Status::OK(); } void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { - VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); // When Xprof profiling is off (which is the default), constructing the // activity is simple enough that its overhead is negligible. @@ -336,7 +364,7 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { - VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), op_kernel->IsExpensive()); @@ -358,21 +386,17 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, if (alloc_attrs.on_host()) { *tensor = parsed; } else { - Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); + Allocator* allocator = GetAllocatorLocked(alloc_attrs); + Tensor copy(allocator, parsed.dtype(), parsed.shape()); Notification n; - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream, - GetDeviceToHostStream()); - TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream, - GetHostToDeviceStream()); - XlaTransferManager manager(stream, host_to_device_stream, - device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_); - manager.CopyCPUTensorToDevice(&parsed, this, ©, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }); + device_context->CopyCPUTensorToDevice(&parsed, this, ©, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); *tensor = copy; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index fccdb143680353ccbe3106bd48aa297980179d55..d8906419b0c406026bb7e10007b2f0a2b4832d01 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -25,10 +25,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/allocator.h" @@ -39,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace tensorflow { @@ -116,62 +119,85 @@ class XlaDevice : public LocalDevice { const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; - Allocator* GetAllocator(AllocatorAttributes attr) override; + Allocator* GetAllocator(AllocatorAttributes attr) override + LOCKS_EXCLUDED(mu_); void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override { return Status::OK(); } Status FillContextMap(const Graph* graph, - DeviceContextMap* device_context_map) override; + DeviceContextMap* device_context_map) override + LOCKS_EXCLUDED(mu_); Status MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + Tensor* tensor) override LOCKS_EXCLUDED(mu_); - xla::LocalClient* client() const; const Metadata& metadata() { return xla_metadata_; } - xla::StatusOr GetStream(); - xla::StatusOr GetHostToDeviceStream(); - xla::StatusOr GetDeviceToHostStream(); - // If not already set, create and set GpuDeviceInfo. - // Not thread-safe - Status CreateAndSetGpuDeviceInfo(); + // Ensures the DeviceContext associated with this XlaDevice is created and + // valid (i.e. all streams are ok). If any state is not valid, a new + // DeviceContext will be created. + // + // TODO(b/111859745): The Eager context needs to call this method to recover + // from failures. + Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_); + + // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra + // information for GPU and TPU devices. + Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); private: + xla::LocalClient* client() const; + Allocator* GetAllocatorLocked(AllocatorAttributes attr) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, + xla::StreamPool::Ptr* stream, + bool* stream_was_changed) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + xla::StatusOr GetDeviceContextLocked() + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - DeviceType jit_device_name_; + const DeviceType jit_device_name_; + // The platform for this device. + se::Platform* const platform_; // Not owned. // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::Backend::StreamPtr stream_; - // If true, only stream_ is valid and all computation and transfers use - // stream_. If false, computation is performed by stream_ and transfers are + xla::StreamPool::Ptr stream_ GUARDED_BY(mu_); + // If false, only stream_ is valid and all computation and transfers use + // stream_. If true, computation is performed by stream_ and transfers are // performed by host_to_device/device_to_host_stream. - bool use_multiple_streams_; + const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. - xla::Backend::StreamPtr host_to_device_stream_; + xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, device to host transfers are performed using this // stream. - xla::Backend::StreamPtr device_to_host_stream_; + xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. - bool transfer_as_literal_; - XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // The device context accessed by all users of the XlaDevice, set by calls to + // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is + // also filled in to that struct. XlaDeviceContext is a ref-counted object. + XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr; - // If set, holds default device context (that we must Unref) - // and its stream. - std::unique_ptr gpu_device_info_; + // Holds extra information for GPU and TPU devices, e.g. the device context. + bool use_gpu_device_info_ GUARDED_BY(mu_) = false; + std::unique_ptr gpu_device_info_ GUARDED_BY(mu_); }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 8cf198239c84c3720585f53ebc95876ce4396793..0100bf51ed2a66f6d110dacd30bcdf9f48a8f64f 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -101,34 +101,27 @@ Status XlaTransferManager::TransferLiteralToDevice( // Unref the host tensor, and capture the literal shared_ptr too so it goes // out of scope when the lambda completes. host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } void XlaTransferManager::TransferLiteralFromDevice( Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const { + xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); + const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_, shaped_buffer, - [=, &shaped_buffer]( - xla::StatusOr > literal_or) { + device_to_host_stream_, shaped_buffer, literal, + [=, &shaped_buffer, &literal](xla::Status status) { ref.Unref(); done([&]() -> Status { - TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() + VLOG(1) << "Transfer from device as literal: " << literal.ToString() << " " << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - Status status; - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - status = errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } return status; }()); }); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 6adda327f186a607b4e7371bf4c5071dd86582da..da3e329247e825d4a33a53dc310899d6ba6ce9cf 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,7 +23,11 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/function_ops.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" @@ -166,7 +170,69 @@ class XlaAssignVariableOp : public AsyncOpKernel { QueueIsClosedOp); \ \ REGISTER_KERNEL_BUILDER( \ - Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T", TYPES) \ + .HostMemory("input"), \ + RetvalOp); \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ + GeneratorDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ + .Device(DEVICE) \ + .HostMemory("buffer_size") \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + PrefetchDatasetOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ + IteratorHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ + MakeIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ + AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ + IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorToStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); // TODO(phawkins): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 851b118b0c18cfd752302b8f8dec27dae3e12acd..ef4466f0056ea98adc1ae6774105466af0d14293 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -59,7 +59,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, } // TODO(b/78468222): Uncomment after fixing this bug - // status = device->CreateAndSetGpuDeviceInfo(); + // status = device->UseGpuDeviceInfo(); // if (!status.ok()) { // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, // " device"); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 080bed50e68ba353a5029f5eb959003b51327f4a..ae98b3f0f9d5dac66b9716ad84a9f0371511e9b6 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -673,6 +673,7 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], + shard_count = 5, tags = ["optonly"], deps = [ ":xla_test", @@ -690,11 +691,7 @@ tf_xla_py_test( size = "small", srcs = ["random_ops_test.py"], disabled_backends = [ - # TODO(b/110300529): RngNormal doesn't return values with the expected variance - "cpu", "cpu_ondemand", - # TODO(b/31361304): enable RNG ops on GPU when parallelized. - "gpu", ], deps = [ ":xla_test", @@ -1002,6 +999,7 @@ tf_xla_py_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], + shard_count = 5, # Times out in fastbuild mode. tags = ["optonly"], deps = [ diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 03554d6933aca39b428c6af4be0c78e2c7ccb0c9..0d2e4d029636577adc74784d9a8b3494b94dc67d 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -52,6 +52,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -91,6 +94,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -130,6 +136,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 6ead15da13b86b9d2b4cf2c19e5cf2a90b061b91..422f36d43bf38d26f057c18da716d7e281c286af 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -400,6 +400,21 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertEqual(75, y.numpy()) self.assertEqual(30, dy.numpy()) + def testGradientTapeInDefun(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def f(): + x = constant_op.constant(1.0) + with backprop.GradientTape() as tape: + y = v0 * x + dy = tape.gradient(y, v0) + return dy + + dy = f() + self.assertEqual(1.0, dy.numpy()) + def testSliceInDefun(self): with self.test_scope(): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 8b01ef96db3e8ab58850df234c2e05b764be52ba..bf986ade06b11358552ee92df3169f965ce3f534 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -26,6 +26,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -579,5 +580,140 @@ class ResizeBilinearTest(xla_test.XLATestCase): large_tolerance=True) +class NonMaxSuppressionTest(xla_test.XLATestCase): + + def testNMS128From1024(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + + def testNMS3From6Boxes(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + # Three boxes are selected based on IOU. + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + + def testNMS3Then2WithScoreThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 14c5e7a975e478ca6ceed37c28339b40612801c8..cc0e9b2f98dc2cdb0382140d5172ed51d8ab2b53 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -57,7 +57,8 @@ class RandomOpsTest(xla_test.XLATestCase): def testRandomUniformIsNotConstant(self): def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) + dtype = dtypes.as_dtype(dtype) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) @@ -73,6 +74,11 @@ class RandomOpsTest(xla_test.XLATestCase): def testRandomUniformIsInRange(self): for dtype in self._random_types(): + # TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is + # fixed. + if (self.device in ["XLA_GPU", "XLA_CPU" + ]) and (dtype in [dtypes.bfloat16, dtypes.half]): + continue with self.test_session() as sess: with self.test_scope(): x = random_ops.random_uniform( @@ -124,20 +130,29 @@ class RandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + atol = 2e-4 + if self.device in ["XLA_GPU", "XLA_CPU"]: + atol = 2.2e-4 + self.assertAllClose(actual_mean, expected_mean, atol=atol) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma actual_median = np.median(y) - self.assertAllClose(actual_median, expected_median, atol=8e-4) + self.assertAllClose(actual_median, expected_median, atol=1e-3) expected_variance = sigma**2 * (1 + ( (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) actual_variance = np.var(y) - self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) + rtol = 1e-3 + if self.device in ["XLA_GPU", "XLA_CPU"]: + rtol = 4e-4 + self.assertAllClose(actual_variance, expected_variance, rtol=rtol) def testShuffle1d(self): + # TODO(b/26783907): this test requires the CPU backend to implement sort. + if self.device in ["XLA_CPU"]: + return with self.test_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 16f293891d56d78885dd515bb7b9899faf0690f7..c0ea242044540b1cef44186880ba3cd92b8849d6 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -62,6 +62,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -101,6 +102,9 @@ class OpTestBuilder { OpTestBuilder& RandomInput(DataType type); OpTestBuilder& RandomInput(DataType type, std::vector dims); + // As RandomInput but the values are unique. + OpTestBuilder& RandomUniqueInput(DataType type, std::vector dims); + // Sets an attribute. template OpTestBuilder& Attr(StringPiece attr_name, T&& value); @@ -126,6 +130,7 @@ class OpTestBuilder { DataType type = DT_INVALID; bool has_dims = false; + bool needs_unique_values = false; std::vector dims; }; @@ -167,6 +172,18 @@ OpTestBuilder& OpTestBuilder::RandomInput(DataType type, return *this; } +OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, + std::vector dims) { + VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString(); + InputDescription input; + input.type = type; + input.has_dims = true; + input.needs_unique_values = true; + input.dims = std::move(dims); + inputs_.push_back(input); + return *this; +} + template OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); @@ -289,7 +306,8 @@ class OpTest : public ::testing::Test { // Returns a tensor filled with random but "reasonable" values from the middle // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. - Tensor RandomTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. @@ -432,49 +450,90 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, return dims; } -Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { +Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) -> float { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> float { + float generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_DOUBLE: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0, 1.0); - test::FillFn(&tensor, [this, &distribution](int i) -> double { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> double { + double generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_COMPLEX64: { + gtl::FlatSet> already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) { - return complex64(distribution(generator()), distribution(generator())); + test::FillFn(&tensor, [&](int i) { + complex64 generated; + do { + generated = + complex64(distribution(generator()), distribution(generator())); + } while ( + needs_unique_values && + !already_generated + .insert(std::make_pair(generated.real(), generated.imag())) + .second); + return generated; }); break; } case DT_INT32: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); - test::FillFn(&tensor, [this, &distribution](int i) -> int32 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int32 { + int32 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_INT64: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1LL << 40), 1LL << 40); - test::FillFn(&tensor, [this, &distribution](int i) -> int64 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int64 { + int64 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_BOOL: { + gtl::FlatSet already_generated; std::bernoulli_distribution distribution; - test::FillFn(&tensor, [this, &distribution](int i) -> bool { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> bool { + bool generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } @@ -485,7 +544,7 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { } Tensor OpTest::RandomTensor(DataType dtype) { - return RandomTensor(dtype, RandomDims()); + return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims()); } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, @@ -761,7 +820,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Ignoring oversize dims."; return kInvalid; } - input_tensors.push_back(RandomTensor(input.type, dims)); + input_tensors.push_back( + RandomTensor(input.type, input.needs_unique_values, dims)); } VLOG(1) << "Input: " << input_tensors.back().DebugString(); } @@ -960,7 +1020,7 @@ TEST_F(OpTest, ArgMax) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMax") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) @@ -976,7 +1036,7 @@ TEST_F(OpTest, ArgMin) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMin") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 5f25ff9002964e94db384d7b01f07cfc4f8938b1..73adb0d243b3b27e6c6ba669b2fd134a5976a2ec 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -361,6 +361,12 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-0.05, 6.05, 5]], dtype=dtype), expected=np.array([[0, 6, 5]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428], + dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softmax, np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), @@ -369,6 +375,14 @@ class UnaryOpsTest(xla_test.XLATestCase): [0.032058604, 0.087144323, 0.23688284, 0.64391428]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), + expected=np.array( + [[[0.5, 0.5], [0.5, 0.5]], + [[0.26894142, 0.73105858], [0.26894142, 0.73105858]]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 06d977b93c28792704b910c688af510bc650d2a4..85084bb1240cf05f6eabfbea772df113cabe613c 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops @@ -47,6 +49,34 @@ class XlaDeviceTest(xla_test.XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testCopiesOfUnsupportedTypesFailGracefully(self): + """Tests that copies of unsupported types don't crash.""" + test_types = set([ + np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, + np.int64, np.float16, np.float32, np.float16, + dtypes.bfloat16.as_numpy_dtype + ]) + shape = (10, 10) + for unsupported_dtype in test_types - self.all_types: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(unsupported_dtype, shape) + with self.test_scope(): + y, = array_ops.identity_n([x]) + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape) + inputs = inputs.astype(unsupported_dtype) + # Execution should either succeed or raise an InvalidArgumentError, + # but not crash. Even "unsupported types" may succeed here since some + # backends (e.g., the CPU backend) are happy to handle buffers of + # unsupported types, even if they cannot compute with them. + try: + sess.run(z, {x: inputs}) + except errors.InvalidArgumentError: + pass + def testControlTrigger(self): with self.test_session() as sess: with self.test_scope(): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ff002d15b09f5a9dc1a69741672503656f893298..61759fd2764205fab7fce11c4003e84be1be813a 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -81,7 +81,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -91,6 +91,18 @@ cc_library( ], ) +cc_library( + name = "cpu_function_runtime", + srcs = ["cpu_function_runtime.cc"], + hdrs = ["cpu_function_runtime.h"], + deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + ], +) + cc_library( name = "xla_compiled_cpu_function", srcs = ["xla_compiled_cpu_function.cc"], @@ -99,12 +111,23 @@ cc_library( deps = [ # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. - "//tensorflow/compiler/aot:runtime", + ":cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], ) +tf_cc_test( + name = "cpu_function_runtime_test", + srcs = ["cpu_function_runtime_test.cc"], + deps = [ + ":cpu_function_runtime", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "xla_jit_compiled_cpu_function", srcs = ["xla_jit_compiled_cpu_function.cc"], @@ -119,6 +142,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/core:lib", @@ -139,14 +163,12 @@ cc_library( "xla_op_registry.cc", "xla_resource.cc", "xla_cpu_backend.cc", - "legacy_flags/backend_registration_flags.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]), hdrs = [ "const_analysis.h", "graph_compiler.h", - "legacy_flags/backend_registration_flags.h", "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", @@ -172,16 +194,14 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -294,6 +314,7 @@ tf_cc_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc similarity index 83% rename from tensorflow/compiler/aot/runtime.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime.cc index 5e74079fc158379b8977ada6412141e39142c3d3..2ffad2af8cfe621f0cbbdd8a9484ef2dfdf1b129 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,22 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" - -#include +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/platform/dynamic_annotations.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { - namespace { - // Inline memory allocation routines here, because depending on '//base' brings // in libraries which use c++ streams, which adds considerable code size on // android. -inline void* aligned_malloc(size_t size, int minimum_alignment) { +void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); #elif defined(_WIN32) @@ -47,7 +41,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { +void aligned_free(void* aligned_memory) { #if defined(_WIN32) _aligned_free(aligned_memory); #else @@ -58,13 +52,13 @@ inline void aligned_free(void* aligned_memory) { size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; } - } // namespace -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { +namespace cpu_function_runtime { +size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) { size_t total = 0; for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { + if (sizes[i] > 0) { total += align_to(sizes[i], kAlign); } } @@ -73,7 +67,7 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, bool annotate_initialized) { - const size_t total = aligned_buffer_bytes(sizes, n); + const size_t total = AlignedBufferBytes(sizes, n); void* contiguous = nullptr; if (total > 0) { contiguous = aligned_malloc(total, kAlign); @@ -85,7 +79,9 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, } uintptr_t pos = reinterpret_cast(contiguous); for (size_t i = 0; i < n; ++i) { - if (sizes[i] == -1) { + if (sizes[i] < 0) { + // bufs[i] is either a constant, an entry parameter or a thread local + // allocation. bufs[i] = nullptr; } else { bufs[i] = reinterpret_cast(pos); @@ -100,7 +96,5 @@ void FreeContiguous(void* contiguous) { aligned_free(contiguous); } } - -} // namespace runtime -} // namespace tfcompile +} // namespace cpu_function_runtime } // namespace tensorflow diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h similarity index 70% rename from tensorflow/compiler/aot/runtime.h rename to tensorflow/compiler/tf2xla/cpu_function_runtime.h index d1a669ceb17b9fd71d26e978035283f8824b0376..c7b4559c65731d1c4f4ea41e8be173ba89fe359c 100644 --- a/tensorflow/compiler/aot/runtime.h +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,25 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file contains utilities to make it easier to invoke functions generated -// by tfcompile. Usage of these utilities is optional. - -#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_ -#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_ +#ifndef TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ +#define TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ #include "tensorflow/core/platform/types.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { +namespace cpu_function_runtime { // Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 64; +constexpr size_t kAlign = 64; -// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign -// byte boundaries. -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n); +// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1 +// values. There are `n` entries in `sizes`. Each buffer is aligned to +// kAlign byte boundaries. +size_t AlignedBufferBytes(const intptr_t* sizes, size_t n); // MallocContiguousBuffers allocates buffers for use by the entry point // generated by tfcompile. `sizes` is an array of byte sizes for each buffer, @@ -41,8 +37,8 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n); // temporary buffers. // // A single contiguous block of memory is allocated, and portions of it are -// parceled out into `bufs`, which must have space for `n` entries. Returns the -// head of the allocated contiguous block, which should be passed to +// parceled out into `bufs`, which must have space for `n` entries. Returns +// the head of the allocated contiguous block, which should be passed to // FreeContiguous when the buffers are no longer in use. void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, bool annotate_initialized); @@ -50,9 +46,7 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, // FreeContiguous frees the contiguous block of memory allocated by // MallocContiguousBuffers. void FreeContiguous(void* contiguous); - -} // namespace runtime -} // namespace tfcompile +} // namespace cpu_function_runtime } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc similarity index 71% rename from tensorflow/compiler/aot/runtime_test.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc index 06ec623eb2dce5f8dc7156fb7e7b9ad57d90c8ee..f4f27a156261ea6872777cef76ecaf7dd7eebe0d 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc @@ -13,39 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { namespace { -TEST(Runtime, AlignmentValue) { +TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 // So any value that we choose must abide by that constraint as well. - EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment); + EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); } -TEST(Runtime, AlignedBufferBytes) { - EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0); +TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) { + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0); static constexpr intptr_t sizesA[1] = {-1}; - EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -56,48 +54,49 @@ void* add_ptr(void* base, uintptr_t delta) { // expected nullptrs, and write to each byte of allocated memory. We rely on // the leak checker to tell us if there's an inconsistency between malloc and // free. We also check the contiguous property. -TEST(Runtime, MallocFreeContiguousBuffers) { +TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test empty sizes. - void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false); + void* base = + cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false); EXPECT_EQ(base, nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with 0 sum. static constexpr intptr_t sizesA[1] = {-1}; void* bufA[1]; - base = MallocContiguousBuffers(sizesA, 1, bufA, false); + base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false); EXPECT_EQ(base, nullptr); EXPECT_EQ(bufA[0], nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum. static constexpr intptr_t sizesB[1] = {3}; void* bufB[1]; - base = MallocContiguousBuffers(sizesB, 1, bufB, false); + base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufB[0], add_ptr(base, 0)); char* bufB0_bytes = static_cast(bufB[0]); bufB0_bytes[0] = 'A'; bufB0_bytes[1] = 'B'; bufB0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum, and annotate_initialized. static constexpr intptr_t sizesC[1] = {3}; void* bufC[1]; - base = MallocContiguousBuffers(sizesC, 1, bufC, true); + base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true); EXPECT_NE(base, nullptr); EXPECT_EQ(bufC[0], add_ptr(base, 0)); char* bufC0_bytes = static_cast(bufC[0]); bufC0_bytes[0] = 'A'; bufC0_bytes[1] = 'B'; bufC0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test mixed sizes. static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; void* bufD[7]; - base = MallocContiguousBuffers(sizesD, 7, bufD, false); + base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); @@ -115,10 +114,8 @@ TEST(Runtime, MallocFreeContiguousBuffers) { } } } - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); } } // namespace -} // namespace runtime -} // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 03603ee9baefd1d20d220faf63c9c1c427ebdf31..24616c01c7e54b2e8662457ca6af23a0bc563e08 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -33,7 +33,7 @@ struct NameCounts { std::unordered_map counts; }; -string MakeUniquePath(string name) { +string MakeUniqueFilename(string name) { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. @@ -50,26 +50,41 @@ string MakeUniquePath(string name) { count = instance.counts[name]++; } - legacy_flags::DumpGraphFlags* flags = legacy_flags::GetDumpGraphFlags(); - string path = strings::StrCat(flags->tf_dump_graph_prefix, "/", name); + string filename = name; if (count > 0) { - strings::StrAppend(&path, "_", count); + strings::StrAppend(&filename, "_", count); } - strings::StrAppend(&path, ".pbtxt"); - return path; + strings::StrAppend(&filename, ".pbtxt"); + return filename; +} + +string WriteTextProtoToUniqueFile( + Env* env, const string& name, const char* proto_type, + const ::tensorflow::protobuf::Message& proto) { + const string& dirname = + legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; + Status status = env->RecursivelyCreateDir(dirname); + if (!status.ok()) { + LOG(WARNING) << "Failed to create " << dirname << " for dumping " + << proto_type << ": " << status; + return "(unavailable)"; + } + string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + status = WriteTextProto(Env::Default(), filepath, proto); + if (!status.ok()) { + LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath + << " : " << status; + return "(unavailable)"; + } + LOG(INFO) << "Dumped " << proto_type << " to " << filepath; + return filepath; } } // anonymous namespace string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, graph_def); - if (!status.ok()) { - VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", + graph_def); } string DumpGraphToFile(const string& name, Graph const& graph, @@ -83,15 +98,7 @@ string DumpGraphToFile(const string& name, Graph const& graph, } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, fdef); - if (!status.ok()) { - VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " - << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 6cc95149a16a59fce8486c5d103ad09e3e262765..0904778f97c95628c81054cd4bc2ff32ff440a33 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -177,8 +177,8 @@ Status CheckNoCycleContains(const Node* node, const int num_nodes) { visited[current_node->id()] = true; for (const Edge* out : current_node->out_edges()) { if (out->dst() == node) { - return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(", - node->def().op(), ") feeds into itself."); + return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), + "(", node->def().op(), ") feeds into itself."); } else if (!visited[out->dst()->id()]) { ready.push_back(out->dst()); } @@ -324,7 +324,7 @@ Status AddMissingFunctionDef(const FunctionDef& fdef, if (library->Find(node.op())) { continue; } - // The function refered by 'SymbolicGradient' node is specified in its + // The function referred by 'SymbolicGradient' node is specified in its // attribute 'f'. if (node.op() == FunctionLibraryDefinition::kGradientOp) { const AttrValue* attr = @@ -437,22 +437,24 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, continue; } if (enter_merge != nullptr) { - return errors::Internal( - "Enter node for loop-varying argument ", arg.enter->name(), - " has multiple successors: ", enter_merge->dst()->name(), " and ", - e->dst()->name()); + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has multiple successors: ", + FormatNodeForError(*enter_merge->dst()), + " and ", FormatNodeForError(*e->dst())); } enter_merge = e; } if (enter_merge == nullptr) { return errors::Internal("Enter node for loop-varying argument ", - arg.enter->name(), " has zero successors"); + FormatNodeForError(*arg.enter), + " has zero successors"); } arg.merge = enter_merge->dst(); if (!IsMerge(arg.merge)) { return errors::InvalidArgument( "Successor of Enter node for loop-varying argument ", - arg.merge->name(), + FormatNodeForError(*arg.merge), " is not a Merge node; got: ", arg.merge->type_string()); } @@ -462,7 +464,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, return errors::InvalidArgument( "Unexpected number of inputs to Merge node for loop-varying " "argument ", - arg.merge->name(), "; expected 2, got ", + FormatNodeForError(*arg.merge), "; expected 2, got ", arg.merge->input_types().size()); } TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), @@ -470,7 +472,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, if (!IsNextIteration(arg.next_iteration)) { return errors::InvalidArgument( "Expected NextIteration node as input to Merge node; got node ", - arg.next_iteration->name(), " with kind ", + FormatNodeForError(*arg.next_iteration), " with kind ", arg.next_iteration->type_string()); } @@ -481,14 +483,14 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, switches.find(edge->dst()) != switches.end()) { if (arg.switch_node != nullptr) { return errors::InvalidArgument("Duplicate Switch successors to ", - arg.merge->name()); + FormatNodeForError(*arg.merge)); } arg.switch_node = edge->dst(); } } if (arg.switch_node == nullptr) { return errors::InvalidArgument("Missing Switch successor to ", - arg.merge->name()); + FormatNodeForError(*arg.merge)); } // Update the device on the Identity outputs of the switch to match their @@ -516,14 +518,15 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, possible_exit.pop_front(); if (IsExit(edge->dst())) { if (arg.exit != nullptr) { - return errors::InvalidArgument("Duplicate Exit successors to ", - arg.switch_node->name()); + return errors::InvalidArgument( + "Duplicate Exit successors to ", + FormatNodeForError(*arg.switch_node)); } arg.exit = edge->dst(); } else { if (!IsIdentity(edge->dst())) { return errors::Unimplemented("General graph between switch (", - arg.switch_node->name(), + FormatNodeForError(*arg.switch_node), ") and exit node of frame ", frame->name, " not supported yet."); } @@ -1470,7 +1473,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, if (!unreachable_nodes.empty()) { return errors::InvalidArgument( "The following nodes are unreachable from the source in the graph: ", - tensorflow::str_util::Join(unreachable_nodes, ", ")); + errors::FormatNodeNamesForError(unreachable_nodes)); } // Builds Frames, indexed by name. diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index aae2f8ee5acd6249f8b6002d94c877f18064f936..ccf249b35d66861888ad5e5e904b5f63b8ac50a1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -1064,7 +1064,10 @@ TEST(FunctionalizeControlFlow, Cycle) { // less -> XlaIf <--> identity. Status status = FunctionalizeControlFlow(graph.get(), &library); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle")) + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle")) + << status.error_message(); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index e1cea03865ce9978e634429b5ce41fe8b245a575..e4fdf0a6186eb69a2e3413838c91616b992ef2d6 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d88a34dfd9bf7e9a5a9e65004b87833dd35de668..3bfe74521fb30639cb08495c729cbaf6232dd996 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -123,12 +123,15 @@ tf_kernel_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -164,7 +167,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -180,7 +184,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -217,8 +221,8 @@ tf_kernel_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:argmax_op", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index e33532828040123243f839ab1aa655b4bbc72520..41a453da80dec6b6f57a4d222e2c33ef6b786a10 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 26fc1620a4f032b3af28de6e3a5af0e965e82341..276d744c096f8996c774964204feaa3762bdb844 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp); +REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index c4af79281d2162b1dbfb0a7881720892f4bc49d2..b3ad0aea84eef601de08909f760699b8700d28f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 26130fd9e7fce75c6d2a5a53cfc85842cf762b35..48f2a005ab16651fe29d0f6f9d881f95693da461 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index e9b2c0b16d39cb3b747c0316621fb01de709b12e..41f540506ba41fbe7f91393e7b8e26a89e72ef0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index d6d4ae89376b67c14af8ef4f3a608fcc83b6fb59..2c328102e0bd84709707f102272691b6aec9a577 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index efbdb76eaaf78904fe783a018940b1b096ec39bd..5078f8662bd397eaa51274ec816c130b8ced92cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 62eebf762b3e063da8ec456cc4726d3cc9b77d1d..8cc2479dd555380da7500abe6b2aca380110333b 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 1784e712b56145bbdff5f1daa2e031b65d0774b6..e7fef77edcba0ea5a521956a704225ac4f7fcb22 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index 4e6d33304c4ae08a0fd1e0a8373267a527087528..547fe48046e8c934e3bc14d02c8448e107c1a406 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e3a32a5c0e2f93237c8c7ebeea3668b5d1ab6c23..f4106051043859a6786705009d76b02a64cd3ff1 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index f4360d8c3f6fc4007c31fdcfd7f7634de15c76d4..da8cf3fc6fa694f592280f8c249d317827d9cd09 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 48ac4867edcef97be001a24f42f6a35225d466c9..5da7972397b32fb4a2f216913e065c04131a3773 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 500a564f3f0489a42dbc9d5b70ae7708a7a43973..db579a5b35d69deb3dca578e31c1b54fada76342 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 9ff3e0222831cb4339943966810eeae451e47a2c..ef1015552d181a183d412f9c269dd5ec608b388f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 4f92dbc8740b697322424058530b8477c35d809a..a5b870f8dbf70bcee331992345d63fd5d986bdca 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index f3149200250935629a6e4bf67bff0c048135ce3e..12b0e38288e8f222ed506a75ec2575f27141c859 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 22cda27567a58f17ca92803d4eccfc1f29f0b8b8..ed44ad218b6dc073583ec339da082b6881ad672d 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 3b86ea34c9e7d943eb9c7de222e0a2be049ebc68..a3389d5b905bf3ee15744ab4fcee193d312e2ae0 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 958231505b50431b9bb267b0a3cc5ed56e3aeb21..cb73053666d4c32bc0a2ef19b174aee1a29f101e 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 81f42e504e4b6f813a29769719a7a7fb5d99b9c5..5fdb1d972c55efb876972d3f472b53a1f7cde1c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 65d42a302fca48c7b5f88813f80e975823f63ddf..c68b0bfd7961892294c2931e5c4c44de534a7740 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 2fd1a34741e1c7235397f9a69dd8444b4679fa22..cdba6680dee3fade5bdf0c453ed672b653072b0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index b2b00e51e3b00fa93c258af489cf0f4a3e6e764b..80bcef966360ec9a1ca63a02741108ce41b31846 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 95faa1d058f4c0d3fa802b157c6daba1e1adaf41..54b21a278229024e3e54e9135548be6b69b077e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 5f041be5df226ed996b21844c0cf92b6dfac005c..35de96e0aab847fa39ef26d5f3052c392062fd7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index d898e43b858bac706d524c7c271f48b1b5fa258f..92346283c31dfe1d638526ac4b26ef762cd7fd14 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index e2160feba00a7272635207ebcb53670cacf34620..462e0e439583e6e2a622e44eecd4874ad5da7562 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -247,6 +247,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index cb4caf7bcb4caaa1bf7e0e79e52bb966a8838db3..33a73fe5fdf403e513be085dd7bcea3255277b4a 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,7 +17,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -311,5 +316,150 @@ class AdjustHueOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +class NonMaxSuppressionOp : public XlaOpKernel { + public: + explicit NonMaxSuppressionOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + void Compile(XlaOpKernelContext* context) override { + // TODO(b/111646731): Improve scalability of this op, using blocking. + int num_boxes_dim = 0; + int coords_dim = 1; + const TensorShape& boxes_shape = context->InputShape("boxes"); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), + errors::InvalidArgument("boxes must be 2-D, currently: ", + boxes_shape.DebugString())); + const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); + OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + errors::InvalidArgument("boxes must have 4 columns", + boxes_shape.DebugString())); + const TensorShape& scores_shape = context->InputShape("scores"); + OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape), + errors::InvalidArgument("scores must be 1-D, currently: ", + scores_shape.DebugString())); + OP_REQUIRES( + context, scores_shape.dim_size(0) == num_boxes, + errors::InvalidArgument("scores size must equal number of boxes", + scores_shape.DebugString())); + OP_REQUIRES(context, pad_to_max_output_size_, + errors::InvalidArgument( + "XLA compilation requires pad_to_max_output_size == True")); + + xla::XlaOp boxes = context->Input("boxes"); + xla::XlaOp scores = context->Input("scores"); + int64 output_size; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); + OP_REQUIRES( + context, output_size >= 0, + errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + xla::XlaOp score_thresh = context->Input("score_threshold"); + xla::XlaOp iou_thresh = context->Input("iou_threshold"); + + xla::XlaBuilder* const builder = context->builder(); + + // Choose a more convenient layout. + xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); + coords_dim = 0; + num_boxes_dim = 1; + + // Shapes are henceforth [1, num_boxes]. + xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp y1 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); + xla::XlaOp y2 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); + xla::XlaOp x1 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); + xla::XlaOp x2 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + xla::XlaOp area = (y2 - y1) * (x2 - x1); + + // Transpose the 1xN tensors, instead of the NxN tensors. + xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); + xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); + xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); + xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); + xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + + // Shapes are henceforth [num_boxes, num_boxes]. + xla::XlaOp i_xmin = xla::Max(x1, x1_t); + xla::XlaOp i_ymin = xla::Max(y1, y1_t); + xla::XlaOp i_xmax = xla::Min(x2, x2_t); + xla::XlaOp i_ymax = xla::Min(y2, y2_t); + auto square_zero = xla::ZerosLike(i_xmin); + + xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * + xla::Max(i_ymax - i_ymin, square_zero); + xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp iou = i_area / u_area; + + xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); + xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); + xla::XlaOp score_cmp_mask = + xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); + xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); + + // Shapes are [num_boxes] after the reduce. + xla::XlaOp included_iou = xla::Not(xla::Reduce( + suppress, + /*init_value=*/xla::ConstantR0(builder, false), + /*computation=*/CreateScalarOrComputation(xla::PRED, builder), + /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_score = + xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); + xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp neg_inf = + xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); + xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); + + xla::XlaOp ones_included = xla::Select( + included, + xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), + xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); + + // num_valid is scalar. + xla::XlaOp num_valid = xla::Reduce( + ones_included, + /*init_value=*/xla::ConstantR0(builder, 0), + /*computation=*/CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + + context->SetOutput(0, selected_indices); + context->SetOutput(1, num_valid); + } + + private: + bool pad_to_max_output_size_; +}; + +REGISTER_XLA_OP( + Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"), + NonMaxSuppressionOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index d6bf92fb3df8d38909df99e11c85ede4fac2bf81..8d75624e74028ea083c3facc4f9578ec14c50e6d 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/math/math_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 9e64711051d31107db1bf6f1966f9ed6f5630c34..f028e361bccd51de0bd69a1d2227c7afaed53455 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 2fb072f827906d40dcf410f0312394c4f568a28d..a11bbe918f7f8eb050aaa40d4344f9cc9e9a10a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index dc934543cb2f94fbe1e8f1f865156eb082d6a127..87ee2d3aede50eb24e65570f106d49030e1d4236 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index aa45b025512cdeb27e3b0cabb3f194a58c6f86f9..6440770c29894c951f010f6c1deb929f4fe79bbf 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index e06c87db7adb1840606208fe15cd68a3ca4d137a..8dfd7de591c4a3c4768dd60b41e03d294ad49397 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index e2ab4b83cfb45b2f9a7f3aba2d2a927d10ad8b85..c0ca881ff82cee04e0c5e35f9a2d5732fabdd8a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 529959dbd90b05f8860360f70e087ef225150600..eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index 3aed47de2603f3e187ad515d4db3f884da4c6cc8..a9b519d8928cc2807831fd6b4f12e60b7d58ea55 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 89fd610bc63349d008836c3c4e6ec8927c232a54..e5937b56c17d01892928b073da09f38941ea1bbb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 12d9cb9bac6b98c8f4c3edc3f6c661acf4466f98..d4d180aff806f12875f0e43f111ee090f6607ef6 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -21,7 +21,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -70,59 +72,53 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } - // Method that builds an initial value to use in reductions. - virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; - - // The reduction operation to apply to each window. - virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; - - // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) = 0; - - void Compile(XlaOpKernelContext* ctx) override { - std::vector ksize = ksize_; - std::vector stride = stride_; - if (ctx->num_inputs() != 1) { - const TensorShape ksize_shape = ctx->InputShape(1); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), - errors::InvalidArgument("ksize must be a vector, not shape ", - ksize_shape.DebugString())); - OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - ksize.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); - - const TensorShape stride_shape = ctx->InputShape(2); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), - errors::InvalidArgument("stride must be a vector, not shape ", - stride_shape.DebugString())); - OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - stride.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + protected: + xla::StatusOr> GetKernelSize(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return ksize_; } - const TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("Input to ", type_string(), - " operator must have ", num_dims(), - " dimensions")); + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(ksize_shape)) { + return errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString()); + } + if (ksize_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window ksize field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector ksize; + auto status = ctx->ConstantInputAsIntVector(1, &ksize); + if (!status.ok()) { + return status; + } + return ksize; + } - xla::XlaBuilder* const b = ctx->builder(); - auto input = - XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, - stride, padding_); - auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - ctx->SetOutput(0, - PostProcessOutput(ctx, pooled, input_type(0), input_shape)); + xla::StatusOr> GetStride(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return stride_; + } + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(stride_shape)) { + return errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString()); + } + if (stride_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window stride field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector stride; + auto status = ctx->ConstantInputAsIntVector(2, &stride); + if (!status.ok()) { + return status; + } + return stride; } protected: @@ -135,24 +131,48 @@ class PoolingOp : public XlaOpKernel { xla::PrimitiveType xla_reduction_type_; }; +// Converts the tensor data format to the one required by the XLA pooling +// library. +xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, + int num_spatial_dims) { + int num_dims = num_spatial_dims + 2; + int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); + gtl::InlinedVector spatial_dimensions(num_spatial_dims); + for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { + spatial_dimensions[spatial_dim] = + GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); + } + return xla::TensorFormat(/*batch_dimension=*/batch_dimension, + /*feature_dimension=*/feature_dimension, + /*spatial_dimensions=*/spatial_dimensions); +} + class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return xla::MinValue(b, xla_reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateMax(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return output; + auto pooling = + xla::MaxPool(ctx->Input(0), ksize, stride, padding_, + XlaTensorFormat(data_format_, input_shape.dims() - 2)); + ctx->SetOutput(0, pooling); } }; @@ -179,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Common computation shared between AvgPool and AvgPoolGrad. Divide each -// element of an image by the count of elements that contributed to that -// element during pooling. +// Divide each element of an image by the count of elements that contributed to +// that element during pooling. static xla::XlaOp AvgPoolDivideByCount( XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, const TensorShape& input_shape, xla::Padding padding, @@ -240,20 +259,34 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return xla::Zero(b, xla_reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateAdd(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, - ksize_, stride_, num_spatial_dims_, - data_format_); + auto xla_data_format = + XlaTensorFormat(data_format_, input_shape.dims() - 2); + auto spatial_padding = MakeSpatialPadding( + input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format); + + // Convert the input to the reduction type. + auto converted_input = + ConvertElementType(ctx->Input(0), xla_reduction_type_); + auto pooling = + xla::AvgPool(converted_input, ksize, stride, spatial_padding, + xla_data_format, padding_ == xla::Padding::kValid); + // Convert the pooling result back to the input type before returning it. + ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index e88221e4f400abeec59d85c1539d4f70bf515d3c..6f4ed496a1774dde68dd9d5fbd37995d615b678c 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -19,7 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 607cad798a98cfa0c6161a8154001926384e724e..2da9340625db08b14b78340c471f096baf15689d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 76bd1e62aa1efd85d6ed489b9a6d22a2bacf2a8b..b11a4ce36da9907ce8fe377c075023a4540797fa 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,7 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index be7f2bce8cb249aa51ca091e02da7dffc7d06743..0d260fa8fcaa513d7854c1e9215952404d555c70 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 8333f9b288e27efe9497306f031980c9eec7c99c..466e79828d111ee7cadcf713703e8f252c63e62c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index ed1d1c661091fd9bc443336626e593e728b82830..b52f0a0ab6290f2019bb58120be5c2364ec15bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -19,7 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index f4b804e54677c7226d8d3429c9e8c27686d19ccf..d35777ccb1271ec6a7c9972c714d06b2415d9c34 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 354fec9be75e9559b204e2afd6ee08dfc7cea872..121750a82a8c5cbe940068555ad273b7e0d22dfc 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 5be70a4ded31a988cb77cdabe3fc8a041bc3ad16..64900e4709fd3e16d21096b0cfff8922906cb0d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -104,7 +104,7 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval"), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index ec15b4cc7a523d5b8d4287bbe3321433f315063b..d962ef4a5f53470838643541f8a1e693d2f4011c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index c810456f94322acfccae18d78efa861eede4648c..03a50ef8a059e5a005c4cc2e5e98acedfea8619a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 27ab3e1bf5b81a901e64a242e5eb343591481efe..ab094d7dd1ce9856a3c2854fd2776827d6c4b76f 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,7 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 14709bb6cbce4b3ae0f7ff859b0fa622c6eda293..f1f32699fee5f03f603f830722fe65622dee5d3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index e2ac7da2c2630725efe3dbcc51c3f3d30e7aca2c..b22ecb7c6dbb42a33a4f4d90b18b20816df16a50 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 5c010c9df23ba6c7732d87fa014879d93ff586ce..6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 6281d6c6533f7f49a269f5c7e52226ba0f1d29f6..a7f5a8f1698b9d02560de427d356e9e6be5caa7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 5798823cd54c66dd179e3611c0041f7c5a1ff2b5..4e0cf99d8e7ff45ed9145981b5e2e637ce4d4e4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 1864584adee357ce35a3e8a38a4e3c58c356bfca..6adc3c58de63ee70c26bed47eebef955893df4a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index a71fbcd901e8919949db5873675a7e3e785bdf4e..025ba827410f1a9f993a8a1855558a2daa86609b 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -20,7 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -37,11 +38,15 @@ class SoftmaxOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape logits_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), - errors::InvalidArgument("logits must be 2-dimensional")); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_shape.DebugString())); - const int kBatchDim = 0; - const int kClassDim = 1; + // Major dimensions are batch dimensions, minor dimension is the class + // dimension. + std::vector batch_dims(logits_shape.dims() - 1); + std::iota(batch_dims.begin(), batch_dims.end(), 0); + const int kClassDim = logits_shape.dims() - 1; const DataType type = input_type(0); const xla::PrimitiveType xla_type = ctx->input_xla_type(0); @@ -55,7 +60,7 @@ class SoftmaxOp : public XlaOpKernel { xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. - auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); + auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); xla::PrimitiveType xla_accumulation_type; @@ -70,9 +75,9 @@ class SoftmaxOp : public XlaOpKernel { auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) - ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim}) + ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - : xla::Div(exp_shifted, sum, {kBatchDim}); + : xla::Div(exp_shifted, sum, batch_dims); ctx->SetOutput(0, softmax); } diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index faaf8964ff7c40d75a493b03e6b400632117cb45..aaeeae01ccb303091a6d37d1aeb4b2a3377dc638 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 8a8525efa186ed4aa02c494f7505f6245677e96e..7327258c31f21f45ff7ffffbc9db7a2a70b4a14c 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 47d282fe9ec664bbc424793e93f778ebb13c6877..4493539fe34f0ce635fdc58660d4ff90af9c9379 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 242638f981198ffd7a9c5b5f6365168de59a1f85..93fc14e9efca868e84444dd0e07d7f0dfa84c042 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index cc4b13d3b933cdc15efd94d3ce7a353a856bcb88..5412e135478361d08965e4621ec52cfb4a792f1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/prng.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index c2165ccd86dfa1c119790beb20af0844fb1bbda8..1062399d91bd9a9bf8c3820c5ecac534c110746d 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 26326f18b844fa9dc48aeedfa5dcff3d09033a18..be1814d8e3ae2c0ddad0134b9288e0ea084aa81b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index c9e56942625a009fb3660f413a845547192460d5..1233a37565d3a40c6dd2882b3139dedbf690a7b6 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 82d4a69777b06cc3dec1ceb1a0a4163dcb1e4667..183879c7602ccbbd74fca6cb9fa3fc94c066c37d 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -47,31 +47,12 @@ class TopKOp : public XlaOpKernel { context, last_dim_size >= k, errors::InvalidArgument("input must have at least k columns. Had ", last_dim_size, ", needed ", k)); - - xla::XlaBuilder* const b = context->builder(); if (last_dim_size < k) { k = last_dim_size; } - const xla::XlaOp input = context->Input(0); - - xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size); - auto input_dims = input_shape.dim_sizes(); - std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); - xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims); - xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32); - - std::vector start_indices(input_shape.dims(), 0); - std::vector limit_indices(input_dims.begin(), input_dims.end()); - limit_indices[last_dim] = k; - std::vector strides(input_shape.dims(), 1); - - xla::XlaOp values = - xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices, - limit_indices, strides)); - xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), - start_indices, limit_indices, strides); - context->SetOutput(0, values); - context->SetOutput(1, indices); + xla::XlaOp output_tuple = TopK(context->Input(0), k); + context->SetOutput(0, xla::GetTupleElement(output_tuple, 0)); + context->SetOutput(1, xla::GetTupleElement(output_tuple, 1)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 98df73024962b8009a74976d473df752d590b47a..be5e91138656716daddcc3c7a68dbb78ecb69103 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 6c721c48fe3af45aff5cd0bd5e74e2693faf9f97..f9148b394212777271f9eba51313ee17b19819af 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index e6ec794cfd4103f622f64a113464c2f4cbfd4215..0bdfc05726105e2d18362a691cbe2aab00bf77f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index f951127bb95cd52864af869676a6b4c4961c1a43..8671632976023fded04c26a9780c1a67638b0916 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index bb27b5d56f3c24dc093a60e698b1080dfb76514d..2c92a585f5679242d672d0402e617ff199b94f17 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 009fdd81b22dcf54e2ed6ee539414a9f5cafb52d..296518229ebf0ba46717afc4f26d5ae1551c2862 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -21,7 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -300,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc deleted file mode 100644 index 661505021f820e2a87a5d414c6fe382bf6153045..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's backend registration modules. - -#include // NOLINT -#include - -#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static BackendRegistrationFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new BackendRegistrationFlags; - flags->tf_enable_prng_ops_gpu = false; - flag_list = new std::vector({ - Flag("tf_enable_prng_ops_gpu", &flags->tf_enable_prng_ops_gpu, - "Whether to enable PRNG ops: [RandomStandardNormal | RandomUniform " - "| RandomUniformInt | TruncatedNormal] on GPU."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// backend registration modules. -void AppendBackendRegistrationFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the BackendRegistrationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BackendRegistrationFlags* GetBackendRegistrationFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h deleted file mode 100644 index 861c923dd51f90be2acbeb23911a93e873aabdce..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_ - -// Legacy flags for the XLA bridge's backend registration modules. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// backend registration modules. -void AppendBackendRegistrationFlags(std::vector* append_to); - -// The values of flags associated with the XLA bridge's backend registration -// module. -typedef struct { - // Whether to enable RandomUniform op on GPU backend. - // TODO (b/32333178): Remove this flag or set its default to true. - bool tf_enable_prng_ops_gpu; -} BackendRegistrationFlags; - -// Return a pointer to the BackendRegistrationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BackendRegistrationFlags* GetBackendRegistrationFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_ diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 30039e256a184197109b334c5b0e0ca4770dbb2d..cb7a40e23d539f758d963791f1c2b4d37374ade5 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -44,9 +44,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -59,9 +59,9 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:protos_all_cc", ], ) @@ -78,12 +78,12 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -100,9 +100,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -119,10 +119,10 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -142,7 +142,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -162,8 +162,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -200,8 +200,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 3c4eec081ba9744226cfbd8d5392220cbf7276f3..f666d22ea44216beef74608bb4d9f33fb2fe82c6 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index d07a9486f18c0b8f26782123a8fba4ba228f71ee..8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 35b137aa2cc0b5e6c2d2b917c0a95410522305c2..87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 0f6e0e9d152ec5daedeb9c0e355bfb9731759094..1bef9bb166c576ec665bb48265b4da200ddca2a0 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 9c8ac7af25e4222f35bedd3816fc817af7e1f068..fc0c1ee838190b1f1a7ca5b901c97e0a35232a97 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 3aa6a9b07539487b954b2d8c8d0e0bbcc49c2b42..abd2316ac961f583dd29f90f43cf6209de30bd6a 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc index 8ff10fbd3fbf9308140af84c752a5a50bec8fd32..5e7cf00ee5e063aef36a9531ff87d8fe6928ca1f 100644 --- a/tensorflow/compiler/tf2xla/lib/random.cc +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h index 2c573fd85b2783fdac13457cdb277cf988ac40c4..59fc5d0433a51328bc78006ab1c3495d908b44ac 100644 --- a/tensorflow/compiler/tf2xla/lib/random.h +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 739032fef7759daee5d10d209ead5e1ffa60ef8c..ba22eff73abab11abeb57283c63318b2e50a9ca1 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 87309e10ede320a81d173cd0a64492f88a2c7376..13a5f1b850a612bddeeac39bef431c19925351ca 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 75c0ad7f7e42cd27d41b723124119e369cf1a35e..04fa10108cef66f429392951eea70e59643a2d29 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -22,7 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 2dce620ba83785e44cc37d39b14fc25a08d074b0..555760b7efabddfb25c9135b109a1c48b487415e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index a29496dec44798eb0b16bf59b7b84e48c6bdd56e..aeebf16028d40189203cdfd815f06a339ee72902 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index a6f5d346cb5ecb85ff6b2306c2502ba31d74cc64..8b5beba383cda45d36e2ee27ca5e3b3c5988b6b7 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 6cb6c088e9d20af05193f0a3da6c2595966eb495..b4905c952820a45371e090aa98466654e2db9661 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 574e70ddeeab8a3041cd730ce2717daec4f82ddf..d64394f1401d7ceea004a59c991ef6f4a1c58b41 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 5b6684c995889efbb1378c7ac4903548891d090a..9493b1f109be0725f7f733b9f9da664264275a69 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2fb66913ada375d53512b9a1115326b3cc2afea4..77da1bf29ced60e490f07abad41cf8ce96232982 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(), + host_tensor->shape(), &xla_shape)); + return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); +} + +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal) { + *literal = xla::MutableBorrowingLiteral( + static_cast(DMAHelper::base(host_tensor)), xla_shape); + + return Status::OK(); +} + Status HostTensorsToBorrowingLiteralTuple( tensorflow::gtl::ArraySlice host_tensors, xla::BorrowingLiteral* literal) { diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 0610a57029e72dff79a84742346f78a42b7f4ff1..09d6fa811669b422532673540e4da47f47e6be4e 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,16 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); +// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer +// owned by 'host_tensor', but is mutable via the xla::Literal methods. +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal); // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index ac768b206e2a8d163a4253432a1911152f89ce86..48568c825b7a0f13011d3d6e8e62ec5db026760f 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index d02fc56c5b8f58f0e4cfe1779ad34fe3b79324c7..432a12a51622b56ae74a677420da321c58960ee6 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index f0b30dcf4e98faa8ab0801a8b0583744bdc669c7..56f7045a98201ed398244f9e3f5ff23788135b75 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fe7ec633eca2504faf6cbb2f5fd7f59780ab7976..e89f4733281194f0263ae8cc4907caa0ad781165 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index d0b9e34e162f3412cd6662a2e2bbfe3df213c4c2..a6e78825334fec748be5fee80669649df699d2fb 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 672e19bd93449ccc31f4af5ded23257b197a3c39..334459138b55a201c15cb87ad9feb6a03a13c5ab 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include -#include "tensorflow/compiler/aot/runtime.h" namespace tensorflow { @@ -26,20 +26,29 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, result_index_(static_data.result_index), args_(new void*[static_data.num_args]), temps_(new void*[static_data.num_temps]), + arg_index_to_temp_index_(new int32[static_data.num_args]), + num_args_(static_data.num_args), arg_names_(static_data.arg_names), result_names_(static_data.result_names), program_shape_(static_data.program_shape), hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { // Allocate arg and temp buffers. if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + alloc_args_ = cpu_function_runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); } - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers( static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); + for (int i = 0; i < static_data.num_temps; i++) { + if (static_data.temp_sizes[i] < -1) { + int32 param_number = -(static_data.temp_sizes[i] + 2); + arg_index_to_temp_index_[param_number] = i; + } + } + // If Hlo profiling is enabled the generated code expects an appropriately // sized buffer to be passed in as the last argument. If Hlo profiling is // disabled the last function argument is still present in the function @@ -50,11 +59,24 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, } } +bool XlaCompiledCpuFunction::Run() { + // Propagate pointers to the argument buffers into the temps array. Code + // generated by XLA discovers the incoming argument pointers from the temps + // array. + for (int32 i = 0; i < num_args_; i++) { + temps_[arg_index_to_temp_index_[i]] = args_[i]; + } + raw_function_(temps_[result_index_], &run_options_, nullptr, temps_, + profile_counters_); + return true; +} + XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); + cpu_function_runtime::FreeContiguous(alloc_args_); + cpu_function_runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] arg_index_to_temp_index_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 48a8c083cacf2f6ecf9dc1817b6174c01385d035..27cfb354bf5f8ede2dcca85065411006c352a575 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -60,9 +60,19 @@ class XlaCompiledCpuFunction { // The raw function to call. RawFunction raw_function; - // Cardinality and sizes of arg and temp buffers. + // Cardinality and size of arg buffers. const intptr_t* arg_sizes = nullptr; size_t num_args = 0; + + // Cardinality and size of temp buffers. + // + // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer. + // + // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The + // corresponding entry in the temp buffer array needs to be set to null. + // + // If temp_sizes[i] < -1 then the i'th temp is the entry parameter + // -(temp_sizes[i] + 2). const intptr_t* temp_sizes = nullptr; size_t num_temps = 0; @@ -113,11 +123,7 @@ class XlaCompiledCpuFunction { // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run() { - raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_, profile_counters_); - return true; - } + bool Run(); // Returns the error message from the previous failed Run call. // @@ -224,6 +230,17 @@ class XlaCompiledCpuFunction { void** args_ = nullptr; void** temps_ = nullptr; + // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for + // XLA generated code to be able to find it. + // + // For now we need to keep around the args_ array because there is code that + // depends on args() returning a void**. However, in the future we may remove + // args_ in favor of using temps_ as the sole storage for the arguments. + int32* arg_index_to_temp_index_; + + // The number of incoming arguments. + int32 num_args_; + // Backing memory for individual arg and temp buffers. void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index cb47581e36bc302ec2789ffbbd962d9dccc368e5..226c89bcf1e66b5afb43cddb03db39b931ca55a8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -28,12 +28,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -688,12 +690,12 @@ Status ValidateFunctionDef(const FunctionDef* fdef, Status ValidateGraph(const Graph* graph, const FunctionLibraryDefinition& flib_def, const DeviceType& device_type, const string& name) { - auto maybe_error = [&](const string& op, const Status& s) -> Status { + auto maybe_error = [&](const Node* node, const Status& s) -> Status { if (!s.ok()) { return errors::InvalidArgument(strings::StrCat( "Detected unsupported operations when trying to compile graph ", name, - " on ", device_type.type_string(), ": ", op, " (", s.error_message(), - ")")); + " on ", device_type.type_string(), ": ", node->def().op(), " (", + s.error_message(), ")", FormatNodeForError(*node))); } return Status::OK(); }; @@ -706,15 +708,15 @@ Status ValidateGraph(const Graph* graph, Status s; if (fdef) { s = ValidateFunctionDef(fdef, flib_def); - TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + TF_RETURN_IF_ERROR(maybe_error(node, s)); continue; } const OpDef* op_def; s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def); - TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + TF_RETURN_IF_ERROR(maybe_error(node, s)); TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); s = FindKernelDef(device_type, node->def(), nullptr, nullptr); - TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + TF_RETURN_IF_ERROR(maybe_error(node, s)); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 079c99797e1f1ec26205e33b3c7c16d3764f15ca..25332c8d8e3210a0217a1ba3f5767115fe6b1d93 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -251,6 +252,12 @@ class XlaCompiler { // The default empty value is invalid. DeviceType device_type = DeviceType(""); + // The device to use during compilation to execute instructions on, for + // example for auto-tuning. + // Valid values are defined by `xla::Backend::devices_ordinal_supported()`. + // -1 indicates the default device should be used. + int device_ordinal = -1; + xla::Client* client = nullptr; // Function library in which to find function definitions. Must be non-null. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 2fb93be01d6bf4dca22b74f64c1d6c8b0d7f6fb5..be00ed8813fdf2778d6af81556001ef51538dd34 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -312,7 +312,7 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { str_util::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[Node: C = Reshape")) + str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -1077,6 +1077,8 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { ASSERT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) + << status.error_message(); } // Tests a graph which has a node with invalid data type. @@ -1101,6 +1103,8 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { EXPECT_TRUE(str_util::StrContains(status.error_message(), "is not in the list of allowed values")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + << status.error_message(); } TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { @@ -1122,9 +1126,10 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: NoOp")) + EXPECT_TRUE( + str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 0dea366476954123ec09c020d493061fa2637c2f..b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,7 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 38d8cd653cbbe5b01325d6b478589d88909bac56..3db37afdba71342cfb20af8841a40cb54709ca73 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index dc98d4fda6ae21411065981a7b7383ef0ad50f44..1398e9ee536a9675e5b703ec3fabf4a8b9d89cbf 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -21,20 +20,6 @@ limitations under the License. namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { - // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to - // slow code. - legacy_flags::BackendRegistrationFlags* flags = - legacy_flags::GetBackendRegistrationFlags(); - VLOG(2) << "flags->tf_enable_prng_ops_gpu: " << flags->tf_enable_prng_ops_gpu; - if (!flags->tf_enable_prng_ops_gpu && - (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || - kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal")) { - return false; - } - // TODO(b/26783907): The GPU backend currently does not implement sort. - if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 4d1b3b1a135c493b3c2fa95967d2c4768ecfa63f..8efb3d55c88757b9366bdf9622287bdd0a72e295 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -26,7 +26,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index d6ca4ab9346593892917e8375b07a8790dc26e79..e6522157a535fc3e4ec96cb0496b6be2e525c336 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 9e17756b27733e2453ea1688d13e1d718c25cfc8..114a9241bdb00526df76478b030a9efa506dd29c 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -57,11 +58,15 @@ xla::StatusOr> ComputeTempSizes( std::vector temp_sizes; temp_sizes.reserve(allocations.size()); for (const xla::BufferAllocation& allocation : allocations) { - // Callers don't allocate temporary buffers for parameters. Nor for - // thread-local buffers, which are lowered to alloca. - if (allocation.is_entry_computation_parameter() || - allocation.is_thread_local()) { + if (allocation.is_constant() || allocation.is_thread_local()) { + // Constants are lowered to globals. Thread locals are lowered to + // allocas. temp_sizes.push_back(-1); + } else if (allocation.is_entry_computation_parameter()) { + // Entry computation parameters need some preprocessing in + // XlaCompiledCpuFunction::Run. See the comment on + // XlaCompiledCpuFunction::StaticData::temp_sizes. + temp_sizes.push_back(-allocation.parameter_number() - 2); } else { temp_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index e8eafb3819f420bf4fd5be2c3930b3da75be58d6..82028c8b9ca9f65a73f8b50edc0a47c7068aba9a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -21,7 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/dma_helper.h" diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 6203cffd806a94743e2df7adc1e52d0255ab7c06..ac9dfe3369078df7392a4ef04679f7d7beacf8bb 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,7 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index baea8149658ec0849ebb570931ca68518ec5284e..7928fa034725206a752cbfe086d01f15cd235df9 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 4de18a77887496d30e3b1407ecd9042e619653af..2438490be13809b9f3571a362900b44cb838e76b 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index f1c383fd9e3fff8a306ba0ddcc3f9ee42c63d66a..fdf13bb18c2567d2994612d15119ae87cbfa9137 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -636,7 +636,7 @@ cc_library( ":window_util", ":xla_data_proto", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ea75ad32d5df7bbadd37e89de6144b264ab6d5d1..2d5d078aa77423cc18bab053b80a7576acbd849e 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -409,7 +409,7 @@ class Array { // Returns the total number of elements in the array. int64 num_elements() const { - return std::accumulate(sizes_.begin(), sizes_.end(), 1, + return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, std::multiplies()); } diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 25666cad40e8f73812593635ccd0ef56fcd9b955..ad3fcee05b80181369bfdf3cdcdb5452ec9e7e89 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -64,6 +64,7 @@ cc_library( hdrs = ["client.h"], deps = [ ":global_data", + ":xla_computation", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", @@ -73,7 +74,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", @@ -100,12 +100,12 @@ cc_library( deps = [ ":client", ":executable_build_options", + ":xla_computation", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", @@ -114,6 +114,7 @@ cc_library( "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", @@ -126,11 +127,11 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", + ":xla_computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", @@ -174,3 +175,60 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", ], ) + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], +) + +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], + deps = [ + ":padding", + ":sharding_builder", + ":xla_computation", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + ":xla_computation", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 8e54311bade531309c70fa045b1e95eb4fce844d..d0ce5e8a6afa262d4cffdfe8431aab570ffd28df 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index d751e183dd6af85094000d39f85cd205ce68cafc..be50cebfcc0e3c19002635dbd280b14048aa0c93 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index 332c96503637344d56e363e19db4880c37ca9684..a551edeab0943ec5213c5cb035644c02c3cf54d7 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 77ba474cf6915ebe56eb7f847f59e52d8b73ffdd..a2f32ab97eab10294a607f35fc79ded1cc2c5792 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -29,8 +29,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -45,7 +45,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", ], ) @@ -58,7 +58,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -72,7 +72,7 @@ cc_library( ":constants", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", ], ) @@ -86,7 +86,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -101,7 +101,7 @@ cc_library( ":constants", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", ], ) @@ -115,7 +115,31 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":arithmetic", + ":constants", + "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "pooling_test", + srcs = ["pooling_test.cc"], + deps = [ + ":pooling", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -131,11 +155,42 @@ cc_library( ":numeric", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", ], ) +cc_library( + name = "sorting", + srcs = ["sorting.cc"], + hdrs = ["sorting.h"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "sorting_test", + srcs = ["sorting_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":sorting", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], @@ -150,8 +205,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index de1d785e199b0f5f448c5e99bdfb972f11972d63..9225b1acd69c214d6f08a45372a8082ed789c18c 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 8367e09450fb376f63de0bf19055eaba2e466318..632e8cc8bc64fad236a0226c6e93079aadde7050 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index b47f5243f008ecb2045456e4505d1a571fbed745..0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc index f1e3439862344c01af15ec0571155ca46a579e54..f4320f65c1f76d4d4c384110b39d6606773aaf01 100644 --- a/tensorflow/compiler/xla/client/lib/constants_test.cc +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index d003d529cc316dfde63f76284f98ae698e1d8034..13db2325569cf2e25e3ff1200adf4b2544dc2f73 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 1df287d7db2fb5498900d1bff51b621915a6b0af..14c259a7fa2a47642663b65d2785e5bbdc040cfd 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h index 212f6583137390fe1e41bb88b71ba041e2d22ff3..efd8cdc25724198633e0bf1c48c4e7d9e4b4c9e1 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc index f56cadc5472cb4df6ad73c7cf27b14dce528761c..8a96ec68d2dca8485215258b1f6731b934e6f2a8 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..7199269a6c889f3589c1148687faf0bb2aaae90a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" + +namespace xla { + +namespace { + +// Common computation shared between AvgPool and AvgPoolGrad. Divide each +// element of an image by the count of elements that contributed to that +// element during pooling. +XlaOp AvgPoolDivideByCountWithGeneralPadding( + XlaOp sums, PrimitiveType dtype, + tensorflow::gtl::ArraySlice input_shape, + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice ksize, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + // The padding shouldn't be included in the counts. We use another + // ReduceWindow to find the right counts. + const int num_spatial_dims = spatial_padding.size(); + + std::vector input_dim_sizes(num_spatial_dims); + std::vector window_dims(num_spatial_dims); + std::vector window_ksize(num_spatial_dims); + std::vector window_stride(num_spatial_dims); + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_dim_sizes[i] = input_shape[dim]; + window_dims[i] = dim; + window_ksize[i] = ksize[dim]; + window_stride[i] = stride[dim]; + } + + XlaBuilder* b = sums.builder(); + // Build a matrix of all 1s, with the same width/height as the input. + auto ones = Broadcast(One(b, dtype), input_dim_sizes); + PaddingConfig padding_config; + for (int i = 0; i < num_spatial_dims; ++i) { + auto dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(spatial_padding[i].first); + dims->set_edge_padding_high(spatial_padding[i].second); + } + auto zero = Zero(b, dtype); + auto padded_ones = Pad(ones, zero, padding_config); + + // Perform a ReduceWindow with the same window size, strides, and padding + // to count the number of contributions to each result element. + auto counts = + ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b), + window_ksize, window_stride, Padding::kValid); + + return Div(sums, counts, window_dims); +} + +// Sums all elements in the window specified by 'kernel_size' and 'stride'. +XlaOp ComputeSums(XlaOp operand, XlaOp init_value, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value)); + PrimitiveType accumulation_type = init_shape.element_type(); + auto add_computation = CreateScalarAddComputation(accumulation_type, b); + return ReduceWindow(operand, init_value, add_computation, kernel_size, + stride, Padding::kValid); + }); +} + +// Creates a padding configuration out of spatial padding values. +PaddingConfig MakeSpatialPaddingConfig( + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + PaddingConfig padding_config; + for (int i = 0; i < 2 + num_spatial_dims; ++i) { + padding_config.add_dimensions(); + } + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + auto padding_dimension = padding_config.mutable_dimensions(dim); + padding_dimension->set_edge_padding_low(spatial_padding[i].first); + padding_dimension->set_edge_padding_high(spatial_padding[i].second); + } + return padding_config; +} + +} // namespace + +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto max_computation = CreateScalarMaxComputation(dtype, b); + auto init_value = MinValue(b, dtype); + return ReduceWindow(operand, init_value, max_computation, kernel_size, + stride, padding); + }); +} + +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto init_value = Zero(b, dtype); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + auto padding_config = + MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format); + auto padded_operand = Pad(operand, Zero(b, dtype), padding_config); + auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride, + data_format); + if (counts_include_padding) { + // If counts include padding, all windows have the same number of elements + // contributing to each average. Divide by the window size everywhere to + // get the average. + int64 window_size = + std::accumulate(kernel_size.begin(), kernel_size.end(), 1, + [](int64 x, int64 y) { return x * y; }); + + auto divisor = ConstantR0WithType(b, dtype, window_size); + return pooled / divisor; + } else { + return AvgPoolDivideByCountWithGeneralPadding( + pooled, dtype, input_size, padding, kernel_size, stride, data_format); + } + }); +} + +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + std::vector input_spatial_dimensions; + std::vector kernel_size_spatial_dimensions; + std::vector stride_spatial_dimensions; + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_spatial_dimensions.push_back(input_size[dim]); + kernel_size_spatial_dimensions.push_back(kernel_size[dim]); + stride_spatial_dimensions.push_back(stride[dim]); + } + return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions, + stride_spatial_dimensions, padding); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..1699c585d3b09a306c21cfa797a9023a8463bd1f --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Tensor format for reduce window operations. +class TensorFormat { + public: + TensorFormat(int batch_dimension, int feature_dimension, + tensorflow::gtl::ArraySlice spatial_dimensions) + : batch_dimension_(batch_dimension), + feature_dimension_(feature_dimension), + spatial_dimensions_(spatial_dimensions.begin(), + spatial_dimensions.end()) {} + + int batch_dimension() const { return batch_dimension_; } + + int feature_dimension() const { return feature_dimension_; } + + int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } + + int num_spatial_dims() const { return spatial_dimensions_.size(); } + + private: + // The number of the dimension that represents the batch. + int batch_dimension_; + // The number of the dimension that represents the features. + int feature_dimension_; + // The dimension numbers for the spatial dimensions. + tensorflow::gtl::InlinedVector spatial_dimensions_; +}; + +// Computes the max pool of 'operand'. +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool of 'operand'. +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding); + +// Returns the list of low and high padding elements in each spatial dimension +// for the given 'padding' specification. +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b4553b60db555ad7c2ab6b695236df745e30683 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +TensorFormat MakeNCHWFormat(int num_spatial_dims) { + tensorflow::gtl::InlinedVector spatial_dimensions; + for (int i = 0; i < num_spatial_dims; ++i) { + spatial_dimensions.push_back(i + 2); + } + return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1, + /*spatial_dimensions=*/spatial_dimensions); +} + +std::vector> MakeGeneralPadding( + XlaOp input, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const xla::TensorFormat& data_format) { + XlaBuilder* b = input.builder(); + Shape operand_shape = b->GetShape(input).ValueOrDie(); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + return MakeSpatialPadding(input_size, kernel_size, stride, padding, + data_format); +} + +// Add singleton batch and feature dimensions to spatial dimensions, according +// to 'data_format' specification. +std::vector ExpandWithBatchAndFeatureDimensions( + tensorflow::gtl::ArraySlice spatial_dim_sizes, + const xla::TensorFormat& data_format) { + const int num_spatial_dims = spatial_dim_sizes.size(); + std::vector tensor_sizes(num_spatial_dims + 2, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + tensor_sizes[dim] = spatial_dim_sizes[i]; + } + return tensor_sizes; +} + +class PoolingTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(PoolingTest, MaxPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kValid, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 5}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}}, + {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/true); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, + {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = kernel_size; + AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, + AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {}, + error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 299a6ac2b630e94567becc3ec139b8c24eab396a..6ef81689489d8117d5951bcb75693c2e3413e4d6 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" @@ -56,7 +56,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { // Performs a single round of the Threefry2x32 algorithm, with a rotation // amount 'rotation'. - auto round = [builder](ThreeFry2x32State v, int rotation) { + auto round = [](ThreeFry2x32State v, int rotation) { v[0] = v[0] + v[1]; v[1] = RotateLeftS32(v[1], rotation); v[1] = v[0] ^ v[1]; diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index ac86390239668eeff1ad9eed0f6c82e10d5db004..ad000b1fa1d0655c8fccc0bb33379f2499b77f26 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc new file mode 100644 index 0000000000000000000000000000000000000000..a904be259a3870a679b2c4699ec01e2a11b1ce46 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +namespace xla { + +XlaOp TopK(XlaOp input, int64 k) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + int last_dim = input_shape.dimensions_size() - 1; + int last_dim_size = input_shape.dimensions(last_dim); + + XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + auto input_dims = input_shape.dimensions(); + std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); + XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); + XlaOp sort_result = Sort(Neg(input), broadcast_s32); + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector strides(input_shape.dimensions_size(), 1); + + XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + return Tuple(builder, {values, indices}); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..b9dfafdd6f957ae050e0f5dbd076d5288235b490 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopK(XlaOp input, int64 k); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fef98c9923096e21a755c6d730de2c7c10852b2d --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using SortingTest = ClientLibraryTestBase; + +XLA_TEST_F(SortingTest, TopK3From8Values) { + XlaBuilder builder(TestName()); + auto x = + ConstantR1(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + xla::GetTupleElement(xla::TopK(x, 3), 0); + ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); +} + +XLA_TEST_F(SortingTest, TopK3From8Indices) { + XlaBuilder builder(TestName()); + auto x_rev = + ConstantR1(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1(&builder, {0, 1, 2}, {}); +} + +XLA_TEST_F(SortingTest, TopKFullSort) { + XlaBuilder builder(TestName()); + const int kSize = 16; + std::mt19937 eng; + std::uniform_real_distribution u_dist(0.0, 100.0); + auto gen = std::bind(u_dist, eng); + std::vector inputs(kSize); + std::generate(inputs.begin(), inputs.end(), gen); + auto x = ConstantR1(&builder, inputs); + xla::GetTupleElement(xla::TopK(x, kSize), 0); + + std::sort(inputs.begin(), inputs.end(), std::greater()); + ComputeAndCompareR1(&builder, inputs, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 534c5098683e8984fe013225e7de2fc48b1bbc1f..081fec7ad92958aa285e4be41394d7b1876e0815 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -98,14 +98,13 @@ std::vector> MakeFakeArgumentsOrDie( << "Computation should have progran shape."; auto program_shape = computation.proto().program_shape(); - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape.parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; + // Create and run a program which produces a tuple with one element per + // parameter, then return the tuple's constituent buffers. + std::vector param_shapes(program_shape.parameters().begin(), + program_shape.parameters().end()); + auto fake_input_tuple = + MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); + return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index dc613099e2b42a60d0c11a654ab5cd41f8bd4f6f..03695ce2a339735e3e49522f4fe1bbf2d83a3834 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 5f9710914bd0ceff55f5b0a2db05e553ce8bd637..cffb24e29beda6a8c40dca2fe709be22892dd489 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/status_macros.h" using xla::source_map_util::InvalidParameterArgument; @@ -29,8 +31,8 @@ using xla::source_map_util::InvalidParameterArgument; namespace xla { namespace { -StatusOr BorrowStreamForDevice(int device_ordinal, - Backend* backend) { +StatusOr BorrowStreamForDevice(int device_ordinal, + Backend* backend) { if (device_ordinal < 0) { device_ordinal = backend->default_device_ordinal(); } @@ -99,11 +101,14 @@ Status LocalExecutable::ValidateExecutionOptions( } } - // Verify that the device the executable was built for is equivalent to the - // device it will run on. - int run_device_ordinal = run_options.device_ordinal() == -1 - ? backend_->default_device_ordinal() - : run_options.device_ordinal(); + // Verify that the device the executable was built for is equivalent + // to the device it will run on. + int run_device_ordinal = run_options.device_ordinal(); + if (run_device_ordinal == -1) { + run_device_ordinal = run_options.stream() != nullptr + ? run_options.stream()->parent()->device_ordinal() + : backend_->default_device_ordinal(); + } TF_ASSIGN_OR_RETURN(bool devices_equivalent, backend_->devices_equivalent( run_device_ordinal, build_options_.device_ordinal())); @@ -141,7 +146,7 @@ StatusOr LocalExecutable::Run( TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); - Backend::StreamPtr stream; + StreamPool::Ptr stream; if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of // `actual_options` (otherwise we will end up using a returned stream in @@ -298,7 +303,7 @@ StatusOr> LocalClient::TransferFromOutfeedLocal( const Shape& shape, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - auto literal = MakeUnique(); + auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, shape, literal.get())); return std::move(literal); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 4d9e0d7cd9d6ddebead1e12b23e94b529038039b..ae23809261757c637ab4aec036750c371ac60cdc 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc similarity index 95% rename from tensorflow/compiler/xla/client/xla_client/xla_builder.cc rename to tensorflow/compiler/xla/client/xla_builder.cc index a9a4b3bc5dff83347d90afceb1a5f253f593be8a..b3b00e2fffe1196b36190ec72d1425bae4e4e276 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include #include @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/sharding_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -44,21 +45,6 @@ int64 GetUniqueId() { return id; } -// Returns true if an instruction with the given opcode can be the root of the -// computation. -bool CanBeRoot(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAfterAll: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kOutfeed: - case HloOpcode::kTrace: - return false; - default: - return true; - } -} - } // namespace XlaOp operator-(const XlaOp& x) { return Neg(x); } @@ -141,28 +127,13 @@ XlaOp XlaBuilder::ReportErrorOrReturn( return ReportErrorOrReturn(op_creator()); } -StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { +StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - - TF_RET_CHECK(root_id != nullptr); + TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); ProgramShape program_shape; - // Not all instructions can be roots. Walk backwards from the last added - // instruction until a valid root is found. - int64 index = instructions_.size() - 1; - for (; index >= 0; index--) { - TF_ASSIGN_OR_RETURN(HloOpcode opcode, - StringToHloOpcode(instructions_[index].opcode())); - if (CanBeRoot(opcode)) { - break; - } - } - if (index < 0) { - return FailedPrecondition("no root instruction was found"); - } - *root_id = instructions_[index].id(); - *program_shape.mutable_result() = instructions_[index].shape(); + *program_shape.mutable_result() = instructions_[root_id].shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -187,8 +158,15 @@ StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { } StatusOr XlaBuilder::GetProgramShape() const { - int64 root; - return GetProgramShape(&root); + TF_RET_CHECK(!instructions_.empty()); + return GetProgramShape(instructions_.back().id()); +} + +StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return GetProgramShape(root.handle()); } void XlaBuilder::IsConstantVisitor(const int64 op_handle, @@ -256,17 +234,29 @@ StatusOr XlaBuilder::Build() { first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + return Build(instructions_.back().id()); +} + +StatusOr XlaBuilder::Build(XlaOp root) { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return Build(root.handle()); +} + +StatusOr XlaBuilder::Build(int64 root_id) { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. - { - int64 root_id; - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), - GetProgramShape(&root_id)); - entry.set_root_id(root_id); - } + TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + entry.set_root_id(root_id); for (auto& instruction : instructions_) { // Ensures that the instruction names are unique among the whole graph. @@ -1098,11 +1088,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { sharding_builder::AssignDevice(0); XlaScopedShardingAssignment scoped_sharding(this, infeed_instruction_sharding); - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN( + infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); } else { - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN( + infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); } // The infeed instruction produces a tuple of the infed data and a token @@ -1634,6 +1624,32 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, }); } +XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, + GetShape(scatter_indices)); + TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); + TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, + update_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferScatterShape( + input_shape, scatter_indices_shape, updates_shape, + to_apply_shape, dimension_numbers)); + + *instr.mutable_scatter_dimension_numbers() = dimension_numbers; + + AddCalledComputation(update_computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kScatter, + {input, scatter_indices, updates}); + }); +} + XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, @@ -1680,7 +1696,7 @@ XlaOp XlaBuilder::Reduce( TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( - operand_shape, init_shape, dimensions_to_reduce, + {&operand_shape, &init_shape}, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { @@ -1865,6 +1881,61 @@ XlaOp XlaBuilder::CrossReplicaSum( }); } +XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + + // The HloInstruction for Alltoall currently only handles the data + // communication: it accepts N already split parts and scatters them to N + // cores, and each core gathers the N received parts into a tuple as the + // output. So here we explicitly split the operand before the hlo alltoall, + // and concat the tuple elements. + // + // First, run shape inference to make sure the shapes are valid. + TF_RETURN_IF_ERROR( + ShapeInference::InferAllToAllShape(operand_shape, split_dimension, + concat_dimension, split_count) + .status()); + + // Split into N parts. + std::vector slices; + slices.reserve(split_count); + const int64 block_size = + operand_shape.dimensions(split_dimension) / split_count; + for (int i = 0; i < split_count; i++) { + slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size, + /*limit_index=*/(i + 1) * block_size, + /*stride=*/1, /*dimno=*/split_dimension)); + } + + // Handle data communication. + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); + std::vector slice_shape_ptrs; + c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + TF_ASSIGN_OR_RETURN( + XlaOp alltoall, + AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices)); + + // Concat the N received parts. + std::vector received; + received.reserve(split_count); + for (int i = 0; i < split_count; i++) { + received.push_back(this->GetTupleElement(alltoall, i)); + } + return this->ConcatInDim(received, concat_dimension); + }); +} + XlaOp XlaBuilder::SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, @@ -2136,11 +2207,6 @@ StatusOr XlaBuilder::BuildConstantSubGraph( TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, LookUpInstruction(root_op)); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); - if (!CanBeRoot(opcode)) { - return InvalidArgument("the operand with opcode %s cannot be root", - root->opcode().c_str()); - } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. @@ -2666,6 +2732,13 @@ XlaOp CrossReplicaSum( replica_group_ids, channel_id); } +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, @@ -2802,6 +2875,13 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, window_bounds); } +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return input.builder()->Scatter(input, scatter_indices, updates, + update_computation, dimension_numbers); +} + void Send(const XlaOp& operand, const ChannelHandle& handle) { return operand.builder()->Send(operand, handle); } diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h similarity index 97% rename from tensorflow/compiler/xla/client/xla_client/xla_builder.h rename to tensorflow/compiler/xla/client/xla_builder.h index 8359d936b7ff2bb75fd36f07b71837779f3f1218..9403d7ca8dabc80a3964b50d29f158a98091f843 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ #include #include @@ -22,7 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -195,9 +195,14 @@ class XlaBuilder { // Builds the computation with the requested operations, or returns a non-ok // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. + // computation being returned. The root of the computation will be the last + // added operation. StatusOr Build(); + // Overload of Build which specifies a particular root instruction for the + // computation. + StatusOr Build(XlaOp root); + // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. // This function is intended to be used where the returned XlaComputation is @@ -225,9 +230,14 @@ class XlaBuilder { // Returns the shape of the given op. StatusOr GetShape(const XlaOp& op) const; - // Returns the (inferred) result for the current computation's shape. + // Returns the (inferred) result for the current computation's shape. This + // assumes the root instruction is the last added instruction. StatusOr GetProgramShape() const; + // Returns the (inferred) result for the current computation's shape using the + // given operation as the root. + StatusOr GetProgramShape(XlaOp root) const; + // Reports an error to the builder, by // * storing it internally and capturing a backtrace if it's the first error // (this deferred value will be produced on the call to @@ -255,6 +265,9 @@ class XlaBuilder { StatusOr IsConstant(const XlaOp& operand) const; private: + // Build helper which takes the id of the root operation.. + StatusOr Build(int64 root_id); + // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -686,9 +699,9 @@ class XlaBuilder { // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // - // - `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. + // - `channel_id`: for Allreduce nodes from different modules, if they have + // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( @@ -697,6 +710,13 @@ class XlaBuilder { const tensorflow::gtl::optional& channel_id = tensorflow::gtl::nullopt); + // Enqueues an operation that do an Alltoall of the operand cross cores. + // + // TODO(b/110096724): This is NOT YET ready to use. + XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -857,6 +877,11 @@ class XlaBuilder { const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Enqueues a Scatter node onto the computation. + XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + // Enqueues a Send node onto the computation for device-to-device // communication, to send the given operand to a Recv instruction that shares // the same channel handle. @@ -964,9 +989,8 @@ class XlaBuilder { // shape. StatusOr Reshape(const Shape& shape, const XlaOp& operand); - // Returns the (inferred) result for the program shape for the current - // computation and fills the root_id in the pointer. - StatusOr GetProgramShape(int64* root_id) const; + // Returns the (inferred) result for the program shape using the given root. + StatusOr GetProgramShape(int64 root_id) const; // Returns shapes for the operands. StatusOr> GetOperandShapes( @@ -1229,6 +1253,9 @@ class XlaBuilder { const XlaOp& operand, const XlaComputation& computation, tensorflow::gtl::ArraySlice replica_group_ids, const tensorflow::gtl::optional& channel_id); + friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); friend XlaOp SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, @@ -1296,6 +1323,10 @@ class XlaBuilder { friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds); + friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); friend void Send(const XlaOp& operand, const ChannelHandle& handle); friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); @@ -1811,9 +1842,9 @@ XlaOp CrossReplicaSum( // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // -// - `channel_id`: for Allreduce nodes from different models, if they have the +// - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be -// applied cross models. +// applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, @@ -1821,6 +1852,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, const tensorflow::gtl::optional& channel_id = tensorflow::gtl::nullopt); +// Enqueues an operation that do an Alltoall of the operand cross cores. +// +// TODO(b/110096724): This is NOT YET ready to use. +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups = {}); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1977,6 +2015,11 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds); +// Enqueues a Scatter node onto the computation. +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + // Enqueues a Send node onto the computation for device-to-device // communication. This operation sends the given operand to // a Recv instruction in a different computation that shares the same channel @@ -2238,4 +2281,4 @@ XlaOp ConstantR4FromArray4D(XlaBuilder* builder, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc similarity index 81% rename from tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc rename to tensorflow/compiler/xla/client/xla_builder_test.cc index 3b8beb2c7840e23752b5f47bbc5f55d89751884d..49a15ec3b449bdec07aa6ecfbc40b7b9f62c3f4e 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -45,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test { return HloModule::CreateFromProto(proto, config); } + // Overload which explicitly specifies the root instruction. + StatusOr> BuildHloModule(XlaBuilder* b, + XlaOp root) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + // Returns the name of the test currently being run. string TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); @@ -292,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, AllToAll) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, + /*split_count=*/2); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. + EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); @@ -319,5 +347,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); } +TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { + XlaBuilder b(TestName()); + XlaOp constant = ConstantR0(&b, 1.0); + Add(constant, ConstantR0(&b, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { + // Specifying a particular root in Build should still include all entry + // parameters. + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + XlaOp x = Parameter(&b, 0, shape, "x"); + XlaOp y = Parameter(&b, 1, shape, "y"); + XlaOp z = Parameter(&b, 2, shape, "z"); + Add(x, Sub(y, z)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter()); + EXPECT_EQ(module->entry_computation()->num_parameters(), 3); + EXPECT_EQ(module->entry_computation()->instruction_count(), 5); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { + XlaBuilder b(TestName()); + XlaBuilder other_b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + + Parameter(&b, 0, shape, "param"); + XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); + + Status status = b.Build(other_param).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("root operation is not in this computation")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD deleted file mode 100644 index 763653c685cb0d821c11bf3e25f526db0dcb4945..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ /dev/null @@ -1,79 +0,0 @@ -# Description: -# The new XLA client libraries. - -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "xla_computation", - srcs = ["xla_computation.cc"], - hdrs = ["xla_computation.h"], - deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_proto", - ], -) - -cc_library( - name = "xla_builder", - srcs = ["xla_builder.cc"], - hdrs = ["xla_builder.h"], - visibility = ["//visibility:public"], - deps = [ - ":xla_computation", - "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client:sharding_builder", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "xla_builder_test", - srcs = ["xla_builder_test.cc"], - deps = [ - ":xla_builder", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/core:test", - ], -) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc similarity index 94% rename from tensorflow/compiler/xla/client/xla_client/xla_computation.cc rename to tensorflow/compiler/xla/client/xla_computation.cc index 72e3935696e0c44ae3893fc8f1ceb261fa5e2646..3543d41fc2656ec028646edebc0bf5b6af7f67a5 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h similarity index 90% rename from tensorflow/compiler/xla/client/xla_client/xla_computation.h rename to tensorflow/compiler/xla/client/xla_computation.h index 0ffba208b1f8683fe1d26107cbfd096b856267f1..71598ef8b296a760b0ee818fce0a59aed5cfc6b4 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ #include @@ -64,4 +64,4 @@ class XlaComputation { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index abd10b164eaef8e75ed304483861baf250c5b954..fb135f5ceda67ce6c001de15b8f3f084ca164826 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -20,7 +20,7 @@ from __future__ import print_function import math -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import xla_shape @@ -85,7 +85,7 @@ class Sharding(object): something we really want to expose to users (especially as the contract for tile_assignment is very strict). """ - if not isinstance(tile_assignment, np.ndarray): + if not isinstance(tile_assignment, _np.ndarray): raise TypeError('Tile assignment must be of type np.ndarray') if not isinstance(tile_shape, xla_shape.Shape): raise TypeError('Tile shape must be of type xla_shape.Shape') diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 15eeb2ea13607d43c995197f8f0e3c58abd4d94a..b72d190d54591384392e79e73e90cf52df04a902 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -297,7 +297,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } - CHECK(IsDenseArray(shape)); + CHECK(IsDenseArray(shape)) << shape.ShortDebugString(); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 0545deb096e9eace5a9713f200e10559aa718441..36e472568ecfdb97c828817ed339260ee7878723 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -71,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) { return out; } -Literal::StrideConfig::StrideConfig( +MutableLiteralBase::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, tensorflow::gtl::ArraySlice dimensions) : dimensions(dimensions), @@ -133,7 +133,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { } Literal::Literal(const Shape& shape, bool allocate_arrays) - : LiteralBase(), shape_(MakeUnique(shape)) { + : MutableLiteralBase() { + shape_ = MakeUnique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); root_piece_->set_subshape(shape_.get()); @@ -159,7 +160,9 @@ void Literal::DeallocateBuffers() { }); } -Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } +Literal::Literal(Literal&& other) : MutableLiteralBase() { + *this = std::move(other); +} Literal& Literal::operator=(Literal&& other) { DCHECK(&other.root_piece_->subshape() == other.shape_.get()); @@ -187,12 +190,13 @@ const SparseIndexArray* LiteralBase::sparse_indices( return piece(shape_index).sparse_indices(); } -SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { +SparseIndexArray* MutableLiteralBase::sparse_indices( + const ShapeIndex& shape_index) { return piece(shape_index).sparse_indices(); } template -Status Literal::CopySliceFromInternal( +Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -225,8 +229,8 @@ Status Literal::CopySliceFromInternal( // proper stride size at the matching dimension. DimensionVector src_indexes(src_base.size(), 0); DimensionVector dest_indexes(dest_base.size(), 0); - Literal::StrideConfig stride_config(src_literal.shape(), shape(), - copy_size); + MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { // Map from multi-dimensional index, to source index. @@ -253,9 +257,10 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { +Status MutableLiteralBase::CopyElementFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( src_literal.shape(), src_index); @@ -275,8 +280,8 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr> Literal::CreateFromProto( - const LiteralProto& proto) { +/* static */ StatusOr> +MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -405,9 +410,9 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { return Status::OK(); } -Status Literal::CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { +Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { const Shape& dest_subshape = ShapeUtil::GetSubshape(shape(), dest_shape_index); const Shape& src_subshape = @@ -482,10 +487,11 @@ Status Literal::MoveFrom(Literal&& src_literal, return Status::OK(); } -Status Literal::CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status MutableLiteralBase::CopySliceFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) << ShapeUtil::HumanString(src_literal.shape()); @@ -543,7 +549,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, shape().element_type()); } -void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { +void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(element_count(), values.bits()); @@ -895,8 +901,8 @@ size_t LiteralBase::Hash() const { return hash_value; } -Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value) { +Status MutableLiteralBase::SetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index, int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -933,7 +939,7 @@ tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( return p.sparse_indices()->At(sparse_element_number); } -void Literal::SortSparseElements(const ShapeIndex& shape_index) { +void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } @@ -1391,11 +1397,11 @@ StatusOr> LiteralBase::ConvertToShape( elements.push_back(std::move(*new_element)); } auto converted = MakeUnique(); - *converted = Literal::MoveIntoTuple(&elements); + *converted = MutableLiteralBase::MoveIntoTuple(&elements); return std::move(converted); } -/* static */ Literal Literal::MoveIntoTuple( +/* static */ Literal MutableLiteralBase::MoveIntoTuple( tensorflow::gtl::MutableArraySlice elements) { std::vector element_shapes; for (const Literal& element : elements) { @@ -1808,7 +1814,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { - // These conditions should have been checked in Literal::CreateFromProto. + // These conditions should have been checked in + // MutableLiteralBase::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); @@ -1900,7 +1907,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } -void* Literal::untyped_data(const ShapeIndex& shape_index) { +void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } @@ -1916,6 +1923,127 @@ string LiteralBase::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } +void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, + Piece* src_piece, + Piece* dest_piece) { + DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape())) + << "src_piece has shape: " + << ShapeUtil::HumanString(src_piece->subshape()) + << "dest_piece has shape: " + << ShapeUtil::HumanString(dest_piece->subshape()); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece); + + dest_piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + dest_piece->set_buffer(src_piece->buffer()); + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(dest_piece->size_bytes(), 0); + } +} + +MutableLiteralBase::~MutableLiteralBase() {} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableBorrowingLiteral& literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( + const MutableBorrowingLiteral& literal) { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); + + return *this; +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableLiteralBase& literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal->shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + MutableBorrowingLiteral literal, const ShapeIndex& view_root) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.piece(view_root).subshape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, + const Shape& shape) + : MutableLiteralBase() { + shape_ = MakeUnique(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + CHECK(!ShapeUtil::IsTuple(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast(src_buf_ptr)); + root_piece_->set_subshape(shape_.get()); +} + +MutableBorrowingLiteral::~MutableBorrowingLiteral() { + if (root_piece_ != nullptr) { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete piece->sparse_indices(); + } + }); + delete root_piece_; + } +} + +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { CHECK(ShapeUtil::IsTuple(shape)); for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { @@ -1932,13 +2060,6 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { } } -LiteralSlice::LiteralSlice(const LiteralBase& literal) - : LiteralBase(), root_piece_(&literal.root_piece()) {} - -LiteralSlice::LiteralSlice(const LiteralBase& literal, - const ShapeIndex& view_root) - : LiteralBase(), root_piece_(&literal.piece(view_root)) {} - BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : LiteralBase(), shape_(MakeUnique(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index dd67dfa8d4a556aea179bc47abfdc9a9c8872c45..92c0f903cbe252a153103aa8514bb5531696bbfe 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -310,9 +310,10 @@ class LiteralBase { // type of literal itself (0 for numeric types, and false for predicates). // // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique(shape), - // followed by the call to Literal::Populate can be used instead. + // MutableLiteralBase::Populate on the result (since that results in zero + // initialization, then reinitialization. Conside if a call to + // MakeUnique(shape), followed by the call to + // MutableLiteralBase::Populate can be used instead. static std::unique_ptr CreateFromShape(const Shape& shape); protected: @@ -534,7 +535,7 @@ class LiteralBase { virtual const Piece& root_piece() const = 0; // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; + friend class MutableLiteralBase; friend class LiteralSlice; friend class BorrowingLiteral; @@ -545,33 +546,10 @@ class LiteralBase { tensorflow::gtl::ArraySlice start_indices) const; }; -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { +// Abstract base class representing a mutable literal in XLA. +class MutableLiteralBase : public LiteralBase { public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } + virtual ~MutableLiteralBase() = 0; // Returns a MutableArraySlice view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the @@ -587,6 +565,10 @@ class Literal : public LiteralBase { // is not a sparse array. SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + // Returns a pointer to the underlying buffer holding the array at the given // shape index. CHECKs if the subshape of the literal at the given ShapeIndex // is not array. @@ -613,21 +595,6 @@ class Literal : public LiteralBase { const ShapeIndex& dest_shape_index = {}, const ShapeIndex& src_shape_index = {}); - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. @@ -730,12 +697,7 @@ class Literal : public LiteralBase { static StatusOr> CreateFromProto( const LiteralProto& proto); - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - + protected: // Returns the piece at the given ShapeIndex. Piece& piece(const ShapeIndex& shape_index) { return const_cast(LiteralBase::piece(shape_index)); @@ -783,12 +745,83 @@ class Literal : public LiteralBase { template Status PopulateInternal(const FnType& generator, bool parallel); + friend class LiteralBase; + friend class MutableBorrowingLiteral; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// The underlying buffer and shape is always owned by this class. +class Literal : public MutableLiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + virtual Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + private: // Deallocate the buffers held by this literal. void DeallocateBuffers(); - friend class LiteralBase; + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); +}; + +// The underlying buffer is not owned by this class and is always owned by +// others. The shape is not owned by this class and not mutable. +class MutableBorrowingLiteral : public MutableLiteralBase { + public: + virtual ~MutableBorrowingLiteral(); + + MutableBorrowingLiteral() : MutableLiteralBase() {} + + MutableBorrowingLiteral(const MutableBorrowingLiteral& literal); + MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal); + + // Implicit conversion constructors. + MutableBorrowingLiteral(const MutableLiteralBase& literal); + MutableBorrowingLiteral(MutableLiteralBase* literal); + MutableBorrowingLiteral(MutableBorrowingLiteral literal, + const ShapeIndex& view_root); + MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + + private: + // Recursively copies the subtree from the `src_piece` at the given child + // index to the `dest_piece`. For buffers only the pointers are copied, but + // not the content. + void CopyPieceSubtree(const Shape& shape, Piece* src_piece, + Piece* dest_piece); }; -std::ostream& operator<<(std::ostream& out, const Literal& literal); // A read-only view of a Literal. A LiteralSlice contains pointers to shape and // literal buffers always owned by others. @@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. + // Shape of this literal. Stored as unique_ptr such that the (default) move + // construction of this class would be trivially correct: the pointer to Shape + // root_piece_ stores will still point to the correct address. std::unique_ptr shape_; }; @@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice LiteralBase::data( } template -tensorflow::gtl::MutableArraySlice Literal::data( +tensorflow::gtl::MutableArraySlice MutableLiteralBase::data( const ShapeIndex& shape_index) { return piece(shape_index).data(); } @@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get( } template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value) { return piece(shape_index).Set(multi_index, value); } template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, NativeT value) { return root_piece().Set(multi_index, value); } @@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, } template -void Literal::AppendSparseElement( +void MutableLiteralBase::AppendSparseElement( tensorflow::gtl::ArraySlice multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); @@ -959,7 +993,8 @@ void LiteralBase::EachCell( } template -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { +inline void MutableLiteralBase::PopulateR1( + tensorflow::gtl::ArraySlice values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { } template -void Literal::PopulateR2( +void MutableLiteralBase::PopulateR2( std::initializer_list> values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 2); @@ -996,7 +1031,7 @@ void Literal::PopulateR2( } template -void Literal::PopulateFromArray(const Array& values) { +void MutableLiteralBase::PopulateFromArray(const Array& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1009,24 +1044,24 @@ void Literal::PopulateFromArray(const Array& values) { } template -void Literal::PopulateR2FromArray2D(const Array2D& values) { +void MutableLiteralBase::PopulateR2FromArray2D(const Array2D& values) { PopulateFromArray(values); } template -void Literal::PopulateR3FromArray3D(const Array3D& values) { +void MutableLiteralBase::PopulateR3FromArray3D(const Array3D& values) { PopulateFromArray(values); } template -void Literal::PopulateR4FromArray4D(const Array4D& values) { +void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { PopulateFromArray(values); } template -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort) { +void MutableLiteralBase::PopulateSparse( + SparseIndexArray indices, tensorflow::gtl::ArraySlice values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices, } template -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { +Status MutableLiteralBase::PopulateInternal(const FnType& generator, + bool parallel) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); @@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) { return Status::OK(); } template -Status Literal::Populate(const FnType& generator) { +Status MutableLiteralBase::Populate(const FnType& generator) { return PopulateInternal(generator, /*parallel=*/false); } template -Status Literal::PopulateParallel(const FnType& generator) { +Status MutableLiteralBase::PopulateParallel(const FnType& generator) { return PopulateInternal(generator, /*parallel=*/true); } template -void Literal::PopulateWithValue(NativeT value) { +void MutableLiteralBase::PopulateWithValue(NativeT value) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 548fbe8a83a3797aa8ac32dc1f6c085fc0100197..5d33df7d40bf3bfcc8012ce1129d532b34555344 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -34,9 +34,9 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; using tensorflow::strings::StrCat; namespace xla { diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index fed0e58e66a04df2ff9554cb0dd0053b7c669803..69ef4f7a2f3ea559a334a11cbe8392b610742bab 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -134,8 +134,7 @@ void MetricTableReport::AppendHeader() { void MetricTableReport::AppendCategoryTable() { const std::vector categories = MakeCategories(&entries_); - AppendLine("********** categories table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** categories table for ", metric_name_, " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); @@ -185,8 +184,8 @@ void MetricTableReport::AppendCategoryTable() { } void MetricTableReport::AppendEntryTable() { - AppendLine("********** ", entry_name_, " table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** ", entry_name_, " table for ", metric_name_, + " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index fe346f9956adaed8d0127e92517b2cd32be05105..c8f2d65c223ccfe20862954c224d016cca421812 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -53,9 +53,9 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 8aefc4cd5e428a82e1e425849ced384c256055ab..8246f76d3443d58f4174cc4f86100f54d6b46928 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" @@ -624,6 +624,7 @@ _FORWARD_BINOP(ShiftRightArithmetic) _FORWARD_BINOP(ShiftRightLogical) _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) +_FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -658,6 +659,9 @@ _FORWARD_UNOP(Asinh) _FORWARD_UNOP(Atanh) _FORWARD_UNOP(Cosh) _FORWARD_UNOP(Sinh) +_FORWARD_UNOP(Real) +_FORWARD_UNOP(Imag) +_FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index dd9e2fbe7299c39389148851d566ecff503eeb71..a568c24c6376e1fe17f5e5a4f6626bf0970985a3 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -341,6 +341,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(ShiftRightLogical) _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) + _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -375,6 +376,9 @@ class LocalComputationBuilder { _FORWARD_UNOP(Atanh) _FORWARD_UNOP(Cosh) _FORWARD_UNOP(Sinh) + _FORWARD_UNOP(Real) + _FORWARD_UNOP(Imag) + _FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 9b8b0aa7f28e64f434bb24f88a3a9cbe177f8a78..5d5a955bfee35b38a61b9a9f792c1b31259ce044 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1029,6 +1029,10 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Atanh; %unignore xla::swig::LocalComputationBuilder::Cosh; %unignore xla::swig::LocalComputationBuilder::Sinh; +%unignore xla::swig::LocalComputationBuilder::Real; +%unignore xla::swig::LocalComputationBuilder::Imag; +%unignore xla::swig::LocalComputationBuilder::Conj; +%unignore xla::swig::LocalComputationBuilder::Complex; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 71351abd593d45fb5080112438a91df368eee173..6f665faf61b25b23a32ce4d0a012543ba18d7e64 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -50,6 +50,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT32; case F64: return NPY_FLOAT64; + case C64: + return NPY_COMPLEX64; case TUPLE: return NPY_OBJECT; default: @@ -83,6 +85,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F32; case NPY_FLOAT64: return F64; + case NPY_COMPLEX64: + return C64; case NPY_OBJECT: return TUPLE; default: @@ -104,6 +108,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT16: case NPY_FLOAT32: case NPY_FLOAT64: + case NPY_COMPLEX64: case NPY_OBJECT: return true; default: @@ -425,6 +430,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_FLOAT64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX64: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -462,6 +470,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_FLOAT64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX64: + CopyLiteralToNumpyArray(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c0105b385b02e13b360ad1fb5af734d2209a92c2..a2c6fc344d192265d536ef7e23ad5c6d7c847014 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -120,6 +120,9 @@ _UNARY_OPS = [ 'Atanh', 'Cosh', 'Sinh', + 'Real', + 'Imag', + 'Conj', ] _BINARY_OPS = [ @@ -144,6 +147,7 @@ _BINARY_OPS = [ 'ShiftRightArithmetic', 'ShiftRightLogical', 'Atan2', + 'Complex', ] diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD index 8999cda5ef852d1246bea45a3312575ec1ac0721..d790c4db6c466a2bf4d2cf30365749fb901f74a0 100644 --- a/tensorflow/compiler/xla/python_api/BUILD +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -10,6 +10,8 @@ py_library( srcs = ["types.py"], deps = [ "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", "//third_party/py/numpy", ], ) diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py index b60f8dce92ace1b2c682374a2605b3a477936bbc..57dfce3971b829d2a3052d347e5d2d322db0c841 100644 --- a/tensorflow/compiler/xla/python_api/types.py +++ b/tensorflow/compiler/xla/python_api/types.py @@ -20,9 +20,10 @@ from __future__ import print_function import collections -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes # Records corresponsence between a XLA primitive type and Python/Numpy types. # @@ -40,76 +41,82 @@ TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [ # Maps from XLA primitive types to TypeConversionRecord. MAP_XLA_TYPE_TO_RECORD = { + xla_data_pb2.BF16: + TypeConversionRecord( + primitive_type=xla_data_pb2.BF16, + numpy_dtype=dtypes.bfloat16.as_numpy_dtype, + literal_field_name='bf16s', + literal_field_type=float), xla_data_pb2.F16: TypeConversionRecord( primitive_type=xla_data_pb2.F16, - numpy_dtype=np.float16, + numpy_dtype=_np.float16, literal_field_name='f16s', literal_field_type=float), xla_data_pb2.F32: TypeConversionRecord( primitive_type=xla_data_pb2.F32, - numpy_dtype=np.float32, + numpy_dtype=_np.float32, literal_field_name='f32s', literal_field_type=float), xla_data_pb2.F64: TypeConversionRecord( primitive_type=xla_data_pb2.F64, - numpy_dtype=np.float64, + numpy_dtype=_np.float64, literal_field_name='f64s', literal_field_type=float), xla_data_pb2.S8: TypeConversionRecord( primitive_type=xla_data_pb2.S8, - numpy_dtype=np.int8, + numpy_dtype=_np.int8, literal_field_name='s8s', literal_field_type=int), xla_data_pb2.S16: TypeConversionRecord( primitive_type=xla_data_pb2.S16, - numpy_dtype=np.int16, + numpy_dtype=_np.int16, literal_field_name='s16s', literal_field_type=int), xla_data_pb2.S32: TypeConversionRecord( primitive_type=xla_data_pb2.S32, - numpy_dtype=np.int32, + numpy_dtype=_np.int32, literal_field_name='s32s', literal_field_type=int), xla_data_pb2.S64: TypeConversionRecord( primitive_type=xla_data_pb2.S64, - numpy_dtype=np.int64, + numpy_dtype=_np.int64, literal_field_name='s64s', literal_field_type=int), xla_data_pb2.U8: TypeConversionRecord( primitive_type=xla_data_pb2.U8, - numpy_dtype=np.uint8, + numpy_dtype=_np.uint8, literal_field_name='s8s', literal_field_type=int), xla_data_pb2.U16: TypeConversionRecord( primitive_type=xla_data_pb2.U16, - numpy_dtype=np.uint16, + numpy_dtype=_np.uint16, literal_field_name='s16s', literal_field_type=int), xla_data_pb2.U32: TypeConversionRecord( primitive_type=xla_data_pb2.U32, - numpy_dtype=np.uint32, + numpy_dtype=_np.uint32, literal_field_name='s32s', literal_field_type=int), xla_data_pb2.U64: TypeConversionRecord( primitive_type=xla_data_pb2.U64, - numpy_dtype=np.uint64, + numpy_dtype=_np.uint64, literal_field_name='s64s', literal_field_type=int), xla_data_pb2.PRED: TypeConversionRecord( primitive_type=xla_data_pb2.PRED, - numpy_dtype=np.bool, + numpy_dtype=_np.bool, literal_field_name='preds', literal_field_type=bool) } @@ -119,6 +126,6 @@ MAP_XLA_TYPE_TO_RECORD = { # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. MAP_DTYPE_TO_RECORD = { - str(np.dtype(record.numpy_dtype)): record + str(_np.dtype(record.numpy_dtype)): record for record in MAP_XLA_TYPE_TO_RECORD.values() } diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py index b040098c294ffaae92b72f678947f99289239314..757e41a78ad2b57d2ef6e1f3055160be22c7b3ed 100644 --- a/tensorflow/compiler/xla/python_api/xla_literal.py +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types @@ -35,7 +35,7 @@ def ConvertLiteralToNumpyArray(literal): type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] if not literal.shape.dimensions: - return np.array( + return _np.array( getattr(literal, type_record.literal_field_name)[0], type_record.numpy_dtype) else: @@ -54,7 +54,7 @@ def ConvertLiteralToNumpyArray(literal): numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') else: raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) - ndarray = np.array( + ndarray = _np.array( getattr(literal, type_record.literal_field_name), copy=False, dtype=type_record.numpy_dtype) @@ -69,11 +69,11 @@ def _ConvertNumpyArrayToLiteral(ndarray): if ndarray.ndim == 0: getattr(literal, type_record.literal_field_name).append( - np.asscalar(ndarray.astype(type_record.literal_field_type))) + _np.asscalar(ndarray.astype(type_record.literal_field_type))) else: # Ndarrays with boolean dtypes need special type conversion with protobufs - if ndarray.dtype in {np.bool_, np.dtype('bool')}: - for element in np.nditer(ndarray): + if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: + for element in _np.nditer(ndarray): getattr(literal, type_record.literal_field_name).append( type_record.literal_field_type(element)) else: diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 6af28958035bbb03e7e1dbb0d0c7bb2c2f25b96d..f158f6b2410352432445f669155aff0af5526abf 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types @@ -111,7 +111,7 @@ def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name # Set the shape's layout based on the ordering of ndarray. # Numpy arrays come in two orders: Fortran (column-major) and C (row-major). - if np.isfortran(ndarray): + if _np.isfortran(ndarray): # Column-major layout. This corresponds to a "dimension order is # minor-to-major" layout in XLA. layout = range(ndarray.ndim) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 6397f1f47915aaa559beda467c26c66795c98f60..a803520876952a0ab67ecb827b1f256c915335f9 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0b1cec1925d4424db086f8a3f62c91ede090189c..44b22a5586dee3f7dd8ea0edbf9deb2090986ac8 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -56,7 +56,7 @@ tf_cc_test( ":grpc_stub", "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 90efee50b4f19056fac8ef1b341b48175903ff83..67886761813f0bb45a600661b017be91ffeade73 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index cba7883fde855ec6b1cd08184eab4a19afa5722d..1b93d72a3e2ad64f54407667b98fc5ef86a13d02 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -256,7 +256,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -564,7 +564,7 @@ cc_library( ":computation_placer", ":device_memory_allocator", ":platform_util", - ":pool", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -598,6 +598,7 @@ cc_library( ":hlo_proto_util", ":platform_util", ":source_map_util", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", @@ -642,7 +643,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -751,8 +752,8 @@ cc_library( ":hlo_execution_profile", ":hlo_graph_dumper", ":hlo_proto", - ":pool", ":shaped_buffer", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -838,7 +839,7 @@ cc_library( hdrs = ["execution_tracker.h"], deps = [ ":backend", - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -946,7 +947,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -1384,6 +1384,18 @@ tf_cc_test( ], ) +cc_library( + name = "while_loop_analysis", + srcs = ["while_loop_analysis.cc"], + hdrs = ["while_loop_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + "//tensorflow/compiler/xla:literal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1391,8 +1403,8 @@ cc_library( deps = [ ":call_inliner", ":hlo", - ":hlo_evaluator", ":hlo_pass", + ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -1663,8 +1675,8 @@ tf_cc_test( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2670,7 +2682,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -2707,7 +2719,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2715,21 +2727,25 @@ tf_cc_test( ) cc_library( - name = "pool", - hdrs = ["pool.h"], + name = "stream_pool", + srcs = ["stream_pool.cc"], + hdrs = ["stream_pool.h"], deps = [ + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) tf_cc_test( - name = "pool_test", - srcs = ["pool_test.cc"], + name = "stream_pool_test", + srcs = ["stream_pool_test.cc"], deps = [ - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:stream_executor_no_cuda", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 26a8a67601e82ddb21918b90f6d41b1d961bf115..946ef6f0d6b9025b84c4b9341f4ec600465d4b1e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -150,6 +150,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleTranspose(HloInstruction* transpose) override; Status HandleSubtract(HloInstruction* sub) override; @@ -1769,13 +1771,12 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { Status AlgebraicSimplifierVisitor::HandleDynamicSlice( HloInstruction* dynamic_slice) { auto operand = dynamic_slice->mutable_operand(0); - auto start_indices = dynamic_slice->operand(1); if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); } - // DynamicSlice where operand has the same size as the output and - // start_indices are all zero is simply equal to operand. - if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) { + // DynamicSlice where operand has the same size as the output is simply equal + // to operand. + if (SameShape(operand, dynamic_slice)) { return ReplaceInstruction(dynamic_slice, operand); } return Status::OK(); @@ -1784,20 +1785,10 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { auto update = dynamic_update_slice->mutable_operand(1); - auto start_indices = dynamic_update_slice->operand(2); - // DynamicUpdateSlice on a scalar just passes through the update argument. - if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { - return ReplaceInstruction(dynamic_update_slice, update); - } - // DynamicUpdateSlice where operand and update have the same size and - // start_indices are all zero is simply equal to update. - // - // (We require start_indices to be all zero because we want this optimization - // not to affect the visible behavior of this op even when the indices are out - // of range. Currently dynamic-update-slice wraps out-of-range indices, so - // we can only remove the op if its indices never wrap.) - if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { + // DynamicUpdateSlice where operand and update have the same size is simply + // equal to update. + if (SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } @@ -2116,6 +2107,21 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( /*reduce_computation=*/function)); } +Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { + auto operand = sort->mutable_operand(0); + int64 dimension_to_sort = sort->dimensions(0); + if (ShapeUtil::IsZeroElementArray(operand->shape()) || + operand->shape().dimensions(dimension_to_sort) <= 1) { + if (sort->operand_count() == 1) { + return ReplaceInstruction(sort, operand); + } + // If it is key/value sort, the output of sort is a tuple. + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index ddf0a513c0164f479fc7e1ebbb597e16aa116215..862cbeeba6b82e1f24a6616b3237dc47d022e9af 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1941,6 +1941,40 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } +TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), keys); +} + +TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0}); + Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); +} + TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { struct ConvTestOptions { int in_batch = 10; @@ -1972,7 +2006,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { // Builds a convolution from and runs algebraic simplification on // the computation. Returns a string description of the result of // simplification. - auto build_and_simplify = [&options, this]() -> string { + auto build_and_simplify = [&options]() -> string { HloComputation::Builder b(TestName()); Window window; @@ -2489,8 +2523,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({0, 0, 0}))), + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), /*slice_sizes=*/{10, 100, 1000})); auto computation = module().AddEntryComputation(builder.Build()); @@ -2523,8 +2557,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction( HloInstruction::CreateParameter(2, slice_shape, "to_update")), slice, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({0, 0, 0}))))); + builder.AddInstruction(HloInstruction::CreateParameter( + 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 95b4cb6d2e694063b648b264bd2454ae0a5469ff..51ebc4763b612884a4453edec5711f78c4006fc3 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -109,11 +109,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { ResolveInternal(data)); for (const auto& shaped_buffer : replicated_buffers) { std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); + ShapeUtil::ForEachSubshape( + shaped_buffer->on_device_shape(), + [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) { + shape_indices.push_back(index); + }); for (const ShapeIndex& index : shape_indices) { TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), shaped_buffer->device_ordinal())); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 349b32451a697dbd6804b44cd1a36419c753bb14..d12be3e007fe0b16ac850d64521f0025d481b5d2 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -96,24 +96,19 @@ Backend::CreateDefaultBackend() { return CreateBackend(backend_options); } -StatusOr Backend::BorrowStream(int device_ordinal) { - TF_ASSIGN_OR_RETURN(auto exec, stream_executor(device_ordinal)); - return BorrowStream(exec); +StatusOr Backend::BorrowStream(int device_ordinal) { + TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); + return BorrowStream(executor); } -StatusOr Backend::BorrowStream( - se::StreamExecutor* executor) { +StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); if (0 == stream_pools_.count(executor)) { stream_pools_.emplace(std::piecewise_construct, std::forward_as_tuple(executor), - std::forward_as_tuple([executor]() { - auto stream = MakeUnique(executor); - stream->Init(); - return stream; - })); + std::forward_as_tuple()); } - return stream_pools_.at(executor).Allocate(); + return stream_pools_.at(executor).BorrowStream(executor); } Backend::Backend( diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 6546602473e3381cf13879ddebd05d34d1f7a055..1bc3796fa48c1627538474d04ef5358ba64dfce9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -63,11 +63,9 @@ class BackendOptions { // // It also offers a pooling API for creation/use of initialized streams: // -// StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie(); +// StreamPool::Ptr stream = backend->BorrowStream().ConsumeValueOrDie(); class Backend { public: - using StreamPtr = Pool::SmartPtr; - // Creates a new backend. static StatusOr> CreateBackend( const BackendOptions& options); @@ -114,13 +112,13 @@ class Backend { // Borrows a stream for use by the caller, either by grabbing it from an // internal pool, or by constructing/initializating it, and returns the result // to the caller. - StatusOr BorrowStream(int device_ordinal); - StatusOr BorrowStream(se::StreamExecutor* executor); + StatusOr BorrowStream(int device_ordinal); + StatusOr BorrowStream(se::StreamExecutor* executor); // Returns a function to borrow a stream, as `BorrowStream` above does. // Purely for convenience, the caller could rather make this anonymous // function itself. - std::function(int)> StreamBorrower() { + std::function(int)> StreamBorrower() { return [this](int device_ordinal) { return BorrowStream(device_ordinal); }; } @@ -169,7 +167,7 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map> stream_pools_ GUARDED_BY(mu_); + std::map stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 32f785a70adf0e7ea3ce281f7ff73224be8d424e..a725351462809e5b670bbf1d79d2dded87e54f07 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -137,9 +137,9 @@ ENTRY entry { if (instruction->opcode() == HloOpcode::kParameter) { continue; } - ASSERT_TRUE(instruction->has_sharding()); - TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice()); - EXPECT_EQ(device, 1); + auto device = instruction->sharding_unique_device(); + ASSERT_TRUE(device); + EXPECT_EQ(*device, 1); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index f7b4c1405dbc8719d8fba5476e6e41d2921ea877..7cf05ca443c00c3b40eeb7d756cf216b45c45c39 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,7 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"")); + sum, /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 14c54ddd135af024327f63418b410da1ed3c4fd4..16e99b57220cc185fbfaa75d30a0de709cf61ee7 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -34,8 +34,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum which can have a tuple output. + // Special handling for cross-replica-sum and sort which can have a tuple + // output. Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleSort(HloInstruction* sort) override; static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { @@ -49,6 +51,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // conversions between F32 and BF16 to make it supported. Status HandleInstruction(HloInstruction* hlo); + // Handle instructions with tuple outputs by examining each output + // independently. + Status HandleMultipleOutputs(HloInstruction* hlo); + // Inserts a conversion HLO that changes the given HLO's output type. Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, HloComputation* computation); @@ -148,22 +154,35 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( HloInstruction* crs) { if (!ShapeUtil::IsTuple(crs->shape())) { return HandleInstruction(crs); + } else { + return HandleMultipleOutputs(crs); } +} + +Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { + if (!ShapeUtil::IsTuple(sort->shape())) { + return HandleInstruction(sort); + } else { + return HandleMultipleOutputs(sort); + } +} - std::vector operand_types(crs->operand_count()); - std::vector output_types(crs->operand_count()); +Status BFloat16NormalizationVisitor::HandleMultipleOutputs( + HloInstruction* hlo) { + std::vector operand_types(hlo->operand_count()); + std::vector output_types(hlo->operand_count()); int64 f32_count = 0; int64 bf16_count = 0; bool has_unsupported_bf16_operand = false; bool has_unsupported_bf16_output = false; - for (int64 i = 0; i < crs->operand_count(); ++i) { - operand_types[i] = crs->operand(i)->shape().element_type(); - output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + operand_types[i] = hlo->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type(); if (operand_types[i] == F32) { f32_count += 1; } else if (operand_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { has_unsupported_bf16_operand = true; } } @@ -171,7 +190,7 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( f32_count += 1; } else if (output_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Output(*crs)) { + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { has_unsupported_bf16_output = true; } } @@ -185,43 +204,43 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( if (operand_types[i] != BF16) { return false; } - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { return true; } - if (bfloat16_support_->SupportsMixedPrecisions(*crs)) { + if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) { return false; } return has_unsupported_bf16_operand || has_unsupported_bf16_output || f32_count > 0; }; - for (int64 i = 0; i < crs->operand_count(); ++i) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { if (should_convert_operand(i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); f32_count += 1; bf16_count -= 1; } } if (!has_unsupported_bf16_output && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 || + (bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 || bf16_count == 0)) { return Status::OK(); } - std::vector materialized_users = crs->users(); - std::vector output_elements(crs->operand_count()); - auto original_shape = crs->shape(); - for (int64 i = 0; i < crs->operand_count(); ++i) { - auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + std::vector materialized_users = hlo->users(); + std::vector output_elements(hlo->operand_count()); + auto original_shape = hlo->shape(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i}); if (output_types[i] != BF16) { output_elements[i] = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); continue; } subshape->set_element_type(F32); auto gte = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); output_elements[i] = computation_->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType(*subshape, BF16), gte)); @@ -229,11 +248,11 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); - // Use the crs' shape temporarily, in order to pass checks in + // Use the hlo' shape temporarily, in order to pass checks in // ReplaceUseWith. - *tuple->mutable_shape() = crs->shape(); + *tuple->mutable_shape() = hlo->shape(); for (auto* user : materialized_users) { - TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } *tuple->mutable_shape() = original_shape; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 830f26422bdc2b3bd789e7d5926bcebac815d34a..f9f1f64998f5b925102dc238941897ff6d441b3f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -251,7 +251,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"")); + /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); @@ -265,6 +266,33 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); + Shape s32_shape = ShapeUtil::MakeShape(BF16, {1024}); + + HloInstruction* key = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "key")); + HloInstruction* value = builder.AddInstruction( + HloInstruction::CreateParameter(1, s32_shape, "value")); + + HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); +} + // Tests that the normalization should not cause unsupported mixed precision due // to resolving unsupported BF16 operand. TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index b21c83a07f69d6ec93cf9305802e4d3af2783bdc..2fb401c4289728f3f59538464c5b8ad49957985b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -215,7 +215,12 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { return false; } - if (ValueTypeAfterChange(value) == BF16) { + // We use the original type for the value because we are going to examine + // the uses of it, instead of the value itself. If ValueTypeAfterChange() + // were used, it would cause problems when there are aliasing buffers, i.e., + // ResolveInconsistencyOfAliasingBuffers() would fail to revert the + // tentative change to BF16 even if the uses require F32. + if (value->shape().element_type() == BF16) { continue; } for (const HloUse& use : value->uses()) { @@ -566,6 +571,9 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( } visited_computations->insert(visited_in_while.begin(), visited_in_while.end()); + } else if (hlo->opcode() == HloOpcode::kFusion) { + ResolveInconsistencyOfAliasingBuffersHelper( + hlo->fused_instructions_computation(), visited_computations); } } // Now adjust parameters of called computations. @@ -769,8 +777,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // propagation in reverse topological order. for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { - if ((*comp_it)->IsFusionComputation()) { - // Fusion computations are handled when visiting the fusion instruction. + if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) { continue; } auto insts = (*comp_it)->MakeInstructionPostOrder(); @@ -778,6 +785,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/true); } + computations_visited_in_backward_pass_.insert(*comp_it); } // It's possible that an instruction does not define a buffer, but the diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index aeafb25ad7215ea3d297e4a8bf7e1ba72d33d528..69b654d30e42b1ed69304206f09120e86831d468 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -508,6 +508,63 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { EXPECT_FALSE(OutputsBF16(dot)); } +// Tests that if the while condition prevents using BF16, no changes should be +// made to the while body and thus the fusion node inside it. +TEST_F(BFloat16PropagationTest, + ConditionPreventsPropagationForFusionInsideWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, shape, "cond_param")); + builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, shape, "body_param")); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_param, {0, 1})); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f)); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(shape, cond, body, add)); + + auto dot = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_FALSE(OutputsBF16(add)); + EXPECT_FALSE(OutputsBF16(body_fusion)); + EXPECT_FALSE(OutputsBF16(body_param)); + EXPECT_FALSE(OutputsBF16(body_transpose)); + EXPECT_FALSE(OutputsBF16(a_f)); +} + // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { @@ -553,10 +610,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( + auto body_dot1 = builder_body.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot2 = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( - HloInstruction::CreateTuple({body_dot, body_rhs})); + HloInstruction::CreateTuple({body_dot1, body_transpose})); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( @@ -575,9 +636,11 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); EXPECT_FALSE(OutputsBF16(rhs)); - EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_dot1)); EXPECT_TRUE(OutputsBF16(body_lhs)); EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_FALSE(OutputsBF16(body_dot2)); + EXPECT_FALSE(OutputsBF16(body_transpose)); EXPECT_TRUE(OutputsBF16(cond_lhs)); EXPECT_FALSE(OutputsBF16(cond_rhs)); EXPECT_TRUE(OutputsBF16(add0)); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b4c7cf0dd8d3520077f2131b65192865d3701602..118a11c8de3c06d240079723f0a5db314cfcace5 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -817,8 +817,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } Status BufferAssigner::AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const FlatSet& colocated_buffers, const FlatSet& colocated_allocations, FlatMap>* @@ -878,8 +877,8 @@ Status BufferAssigner::AssignBuffersForComputation( // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, has_sequential_order, &liveness, &post_order_position, - assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + [has_sequential_order, &liveness, &post_order_position, assignment]( + const LogicalBuffer* a, const LogicalBuffer* b) { // Primary sort is by decreasing buffer size. const int64 a_size = assignment->buffer_size_(*a); const int64 b_size = assignment->buffer_size_(*b); @@ -1342,11 +1341,25 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs or entry parameters. + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. + // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. for (int64 key : {i, j}) { for (auto& buffer : colocated_buffer_sets[key]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer)) { + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { return true; } } @@ -1428,9 +1441,9 @@ void BufferAssigner::BuildColocatedBufferSets( const HloInstruction* while_hlo = instruction; ShapeUtil::ForEachSubshape( while_hlo->shape(), - [this, while_hlo, &points_to_analysis, &buffer_liveness, - buffer_size, computation, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { + [this, while_hlo, &points_to_analysis, buffer_size, + colocated_buffer_sets](const Shape& /*subshape*/, + const ShapeIndex& index) { std::vector colocated_set; // Add while.init. AddBufferToColocatedSet(while_hlo->operand(0), index, @@ -1664,7 +1677,7 @@ StatusOr> BufferAssigner::CreateAssignment( buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/false, colocated_buffers, colocated_allocations, &buffers_to_assign_sequentially, assignment.get())); } @@ -1685,7 +1698,7 @@ StatusOr> BufferAssigner::CreateAssignment( continue; } TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/true, colocated_buffers, colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 4fcf1fc73defcfba16c33224bd9c785675674408..94495290c131e22392079dc2d0237d990b646d3e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -543,8 +542,7 @@ class BufferAssigner { // true, then all assigned buffers have the is_thread_local flag set to // true. Status AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index dea855d39ad759e7c3c13fcd3ccf06e4a1089df7..eccb146a0d7d628870be179a540d9750df3fe41c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1923,6 +1923,74 @@ ENTRY %test_module { EXPECT_NE(slice_param, slice_while1); } +TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + const char* module_str = R"( +HloModule test_module + +%cond.v0 { + %param = s32[] parameter(0) + ROOT %constant = pred[] constant(true) +} + +%cond.v1 { + %param.0 = s32[] parameter(0) + ROOT %constant.0 = pred[] constant(true) +} + +%body.v0 { + ROOT %param.1 = s32[] parameter(0) +} + +%body.v1 { + %param.2 = s32[] parameter(0) + ROOT add = s32[] add(%param.2, %param.2) +} + +ENTRY %test_module { + %constant.42 = s32[] constant(42) + %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0 + %mul = s32[] multiply(%while.0, %while.0) + %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1 + ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Get the instructions in the module. + const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* constant = + module->entry_computation()->GetInstructionWithName("constant.42"); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + const HloInstruction* while1 = bcast->operand(0); + ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); + const HloInstruction* while0 = while1->operand(0)->operand(0); + ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); + + // Run buffer assignment. + auto assignment = RunBufferAssignment(module.get()); + TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, + assignment->GetUniqueSlice(constant, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + + // The constant slice is part of the while0's colocation set (init value), but + // not merged into the while1's colocation set. + EXPECT_EQ(slice_constant, slice_while0); + EXPECT_NE(slice_constant, slice_while1); +} + // Tests that the colocated buffers for while instructions are properly assigned // during buffer assignment such that the result tuple elements are not assigned // to the same buffer. diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index d26486fcfe0b1bc51867de5113cc5e42a0d7b4f0..187ce568cbb6c6666e978b8c8114262313c70ba5 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -29,9 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + namespace xla { Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { @@ -71,6 +75,19 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { return std::move(assignment); } +string DeviceAssignment::ToString() const { + string output = StrCat("Computations: ", computation_count(), + " Replicas: ", replica_count(), "\n"); + for (int computation = 0; computation < computation_count(); ++computation) { + StrAppend(&output, "Computation ", computation, ": "); + for (int replica = 0; replica < replica_count(); ++replica) { + StrAppend(&output, operator()(replica, computation), " "); + } + StrAppend(&output, "\n"); + } + return output; +} + StatusOr ComputationPlacer::DeviceId(int replica, int computation, int replica_count, int computation_count) { diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 737d00e93ecb51a9bd544bbcbe99d93374d108fb..c899ffb9dc562426ef14c0d414469c04debeec70 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -55,6 +55,8 @@ class DeviceAssignment : public Array2D { // due to a StatusOr of an incomplete type (DeviceAssignment). static StatusOr> Deserialize( const DeviceAssignmentProto& proto); + + string ToString() const; }; // A generic implementation of the XLA computation placer, which assigns device diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ace9f96cfb70e487476dd42b14fa241b70f80634..504b61d134a0099d055d0266408e1dfb94af5b2a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -252,6 +252,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", @@ -363,8 +364,8 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -444,6 +445,7 @@ cc_library( deps = [ ":vector_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:transform_utils", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 6a7eb85e3baec3517b8f3ddef6a8dcfae9c9e614..128eea4828b5e514b2ba6b398898e4a5d228e746 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -156,9 +156,26 @@ std::unique_ptr CompilerFunctor::operator()( target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); codegen_passes.run(module); - // Construct ObjectFile from machine code buffer. - return std::unique_ptr( + std::unique_ptr memory_buffer( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); + + if (VLOG_IS_ON(2)) { + llvm::Expected> obj_file = + llvm::object::ObjectFile::createObjectFile(*memory_buffer); + if (obj_file) { + StatusOr disasm_result = + disassembler_->DisassembleObjectFile(*obj_file.get()); + if (disasm_result.ok()) { + XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); + } else { + LOG(WARNING) << "Could not disassemble object file!"; + } + } else { + LOG(WARNING) << "Could convert memory buffer to object file!"; + } + } + + return memory_buffer; } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 29fa29d33ad62a76191cef2de22ccc094b0cf35b..8cbe9a1b0d5b0553b1121d544196412f36f8ce43 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -562,7 +562,9 @@ StatusOr> CpuCompiler::RunBackend( BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -584,6 +586,8 @@ StatusOr> CpuCompiler::RunBackend( std::move(computation_to_profile_idx), &target_machine_features); + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { @@ -747,7 +751,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, BufferAssigner::Run( module, xla::MakeUnique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -776,6 +782,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -831,17 +840,29 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, BufferSizes buffer_sizes; for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate temporary buffers for parameters. - if (allocation.is_entry_computation_parameter()) { - buffer_sizes.push_back(-1); - continue; - } // Callers don't need to allocate anything for thread-local temporary // buffers. They are lowered to allocas. if (allocation.is_thread_local()) { buffer_sizes.push_back(-1); continue; } + + // Callers don't need to allocate anything for constant buffers. They are + // lowered to globals. + if (allocation.is_constant()) { + buffer_sizes.push_back(-1); + continue; + } + + // Callers don't need to allocate anything for entry computation buffers, + // but they do need to stash the pointer to the entry computation buffer + // in the temp buffer table. See the comment on + // XlaCompiledCpuFunction::StaticData::temp_sizes. + if (allocation.is_entry_computation_parameter()) { + buffer_sizes.push_back(-allocation.parameter_number() - 2); + continue; + } + buffer_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 1093559892ddb9c238fd9c1f7e3d419ec7022776..946f5124b87bc011df4f3553077dbb37a3333ed2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable( // guarded by the mutex. compute_function_ = reinterpret_cast(cantFail(sym.getAddress())); + VLOG(1) << "compute_function_ at address " + << reinterpret_cast(compute_function_); } -Status CpuExecutable::AllocateBuffers( +StatusOr, + std::vector>> +CpuExecutable::CreateTempArray( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { - CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + tensorflow::gtl::ArraySlice arguments) { + std::vector unowning_buffers( + assignment_->Allocations().size()); + std::vector owning_buffers( + assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -84,44 +91,51 @@ Status CpuExecutable::AllocateBuffers( VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { + unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer( + allocation.param_shape_index()); VLOG(3) << "allocation #" << i << " is a parameter"; continue; } + if (allocation.is_constant()) { + VLOG(3) << "allocation #" << i << " is a constant"; + continue; + } + if (allocation.is_thread_local()) { VLOG(3) << "buffer #" << i << " is thread-local"; continue; } int64 buffer_size = allocation.size(); - if (!(*buffers)[i].is_null()) { + if (!owning_buffers[i].is_null()) { VLOG(3) << "buffer #" << i << " is in the preallocated result ShapedBuffer"; } else { - TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); + TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); + unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase(); VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << (*buffers)[i].opaque() << "]"; + << owning_buffers[i].opaque() << "]"; } // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelOutputSlice()); VLOG(3) << "result index: " << result_slice.index(); - return Status::OK(); + return {{std::move(unowning_buffers), std::move(owning_buffers)}}; } Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: @@ -131,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction( // // result: Points at the result. // run_options: the ExecutableRunOptions object. - // args_array: An array of pointers, each of which points to a parameter. - // The size of this array is determined by the function's arity - // (ProgramShape). - // temps_array: An array of pointers, each of which points to a temporary - // buffer the computation needs. The size of this array is - // determined by buffer analysis. + // args_array: null + // temps_array: An array of pointers, containing pointers to temporary buffers + // required by the executable adn pointers to entry computation + // parameters. // - std::vector args_array; - for (const ShapedBuffer* argument : arguments) { - args_array.push_back(argument->root_buffer().opaque()); - } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -164,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction( if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[%zu], void* temps[%zu], " + " func(void* result, void* params[null], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters_size); + buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); }; - VLOG(3) << tensorflow::strings::Printf( - " params = [%s]", - tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str()); + VLOG(3) << " params = nullptr"; VLOG(3) << tensorflow::strings::Printf( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); @@ -181,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction( profile_counters); } - compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters); + compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), + profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -249,21 +255,18 @@ StatusOr CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); + std::vector owning_buffers; std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } - TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), - arguments, unowning_buffers, - hlo_execution_profile)); + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunction( + &run_options->run_options(), unowning_buffers, hlo_execution_profile)); - return CreateResultShapedBuffer(run_options, &buffers); + return CreateResultShapedBuffer(run_options, &owning_buffers); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -279,17 +282,15 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - + std::vector owning_buffers; std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &buffers)); + CreateResultShapedBuffer(run_options, &owning_buffers)); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -307,7 +308,6 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( struct AsyncRunTask { CpuExecutable* executable; ServiceExecutableRunOptions run_options; - std::vector arguments; std::vector unowning_buffers; std::shared_ptr> buffers; @@ -315,15 +315,14 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), arguments, unowning_buffers, + &run_options.run_options(), unowning_buffers, /*hlo_execution_profile=*/nullptr)); } }; - host_stream->EnqueueTask(AsyncRunTask{ - this, *run_options, - std::vector(arguments.begin(), arguments.end()), - unowning_buffers, - std::make_shared>(std::move(buffers))}); + host_stream->EnqueueTask( + AsyncRunTask{this, *run_options, std::move(unowning_buffers), + std::make_shared>( + std::move(owning_buffers))}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8dd47bfb865e8a0552542f510d3365cff0d111e0..8af8a5dfec2834678418f069619ba88b01633361 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,20 +85,29 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // Allocate buffers required for execution and assign them to the elements of - // "buffers". "buffers" should be sized to the number of buffers in buffer - // assignment. Each vector element corresponds to a particular Index. If - // a vector element already contains a non-null DeviceMemoryBase, then no - // buffer is assigned for this element. - Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, - int device_ordinal, - std::vector* buffers); + // Creates an array suitable for passing as the "temps" argument to the JIT + // compiled function pointer. + // + // Returns (unowning_buffers, owning_buffers) where: + // + // - unowning_buffers.data() can be passed as the temps argument as-is and + // includes pointers to the scratch storage required by the computation, + // the live-out buffer into which the result will be written and entry + // computation parameters. + // + // - owning_buffers contains owning pointers to the buffers that were + // allocated by this routine. This routine allocates buffers for temporary + // storage and the live-out buffer into which the computation writes it + // result. + StatusOr, + std::vector>> + CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + tensorflow::gtl::ArraySlice arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0..639064040f521a9e84bd87c5d05f674204e4d6e2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -92,9 +92,10 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace -void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, + const void* shape, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireInfeedBufferForDequeue: " << ShapeString(shape, shape_length); @@ -111,9 +112,11 @@ void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, return buffer->data(); } -void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " << ShapeString(shape_ptr, shape_length); @@ -125,8 +128,10 @@ void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( std::move(shape)); } -void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireOutfeedBufferForPopulation: " << ShapeString(shape_ptr, shape_length); @@ -143,9 +148,11 @@ void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( return buffer->data(); } -void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " << ShapeString(shape_ptr, shape_length); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 156166bf2b1ea6d3821da8f67ea2b2eca6825ca6..59bc7e0e16fcc66a010408259a1ccfb2b6bb35fd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -173,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { if (!ShapeUtil::IsTuple(literal_shape)) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from @@ -181,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tensorflow::gtl::ArraySlice dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); - *literal = std::move(*LiteralUtil::CreateFromDimensions( - literal_shape.element_type(), dimensions)); - TF_ASSIGN_OR_RETURN(Shape received_shape, - TransferArrayBufferFromOutfeed( - executor, literal->untyped_data(), size)); - TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size)); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - *literal->mutable_shape_do_not_use() = received_shape; + *literal.mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -201,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( "Nested tuple outfeeds are not yet implemented on CPU."); } - std::vector> elements; std::vector> buffer_data; for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { const Shape& tuple_element_shape = ShapeUtil::GetTupleElementShape(literal_shape, i); - // Note: OSS build didn't like implicit conversion from - // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice dimensions( - tensorflow::bit_cast( - tuple_element_shape.dimensions().data()), - tuple_element_shape.dimensions().size()); - auto empty = LiteralUtil::CreateFromDimensions( - tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->untyped_data(), size}); - elements.push_back(std::move(empty)); + buffer_data.push_back({literal.untyped_data({i}), size}); } TF_ASSIGN_OR_RETURN(Shape received_shape, @@ -230,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( TF_RET_CHECK(GetByteSizeRequirement(literal_shape) == GetByteSizeRequirement(received_shape)); - for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); - } - *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements))); - TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 593575c0fdaddc71cd6bd844fd179096a9fb0fdc..80ef953d532798281c10b7a212b9c4d84a790c27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 1fdeceb8604e0ae31273e0a0e60b360f6c6976ef..645888de783e4025cffd6fa4835e60b84bbd7d99 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1278,10 +1278,10 @@ Status DotOpEmitter::Emit() { // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_); - llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest( - &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs"); - llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest( - &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs"); + llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( + lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( + rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); // Create the loop which does the sum of products reduction. // @@ -1537,36 +1537,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( - llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, - int64 reduction_dimension, tensorflow::StringPiece name_suffix) { - // Prepares the dimension list we will use to emit the loop nest. Outermost - // loops are added first. Add loops in major-to-minor order, and skip the - // reduction dimension. - std::vector dimensions; - const Shape& shape = operand_array.GetShape(); - for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) { - int64 dimension = LayoutUtil::Minor(shape.layout(), i); - if (dimension != reduction_dimension) { - dimensions.push_back(dimension); - } - } - - // Create loop nest with one for-loop for each dimension of the - // output. - llvm_ir::IrArray::Index index = - loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); - // Verify every dimension except the reduction dimension was set in the index. - for (int dimension = 0; dimension < index.size(); ++dimension) { - if (dimension == reduction_dimension) { - DCHECK_EQ(nullptr, index[dimension]); - } else { - DCHECK_NE(nullptr, index[dimension]); - } - } - return index; -} - // Return whether the given shape is a matrix with no padding. static bool IsRank2WithNoPadding(const Shape& shape) { return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index c2eeb0a1f92e3df91cd234599fe1cd5fc97c4625..590032fbe907d7ca90bf69b7ccc3170b8efec72e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -88,17 +88,6 @@ class DotOpEmitter { // Emits a call to the CPU runtime to perform the matrix multiply. Status EmitCallToRuntime(); - // Emits a series of nested loops for iterating over an operand array in the - // dot operation. Loops are constructed in major to minor dimension layout - // order. No loop is emitted for the given reduction_dimension. The function - // returns an IrArray index for the given operand_array containing the indvars - // of the loops. All dimensions of the index are filled except for the - // reduction dimension. name_suffix is the string to append to the names of - // LLVM constructs (eg, basic blocks) constructed by this method. - llvm_ir::IrArray::Index EmitOperandArrayLoopNest( - llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, - int64 reduction_dimension, tensorflow::StringPiece name_suffix); - // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { // The number of rows in the LHS. diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index cf955a8add394c204673be0746a451d4edcadc96..c13d36776f94221598338dca4eadf024c0a892df 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -19,6 +19,8 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -117,9 +119,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } - return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), - hlo->to_apply(), operands, - llvm_ir::IrName(hlo)); + return ir_emitter_->EmitElementalMap(*Cast(hlo), + operands, llvm_ir::IrName(hlo)); }; } return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index d4ac35a604fed73e3082d1a6126f398e656677dd..ca645d3f1da18fb26378a10526c27a7d254896e2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -115,6 +116,19 @@ StatusOr IrEmitter::EmitComputation( computation->root_instruction()->outer_dimension_partitions().size(); } + if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) { + TF_ASSIGN_OR_RETURN( + computation_root_allocation_, + assignment_.GetUniqueTopLevelSlice(computation->root_instruction())); + } + + for (const HloInstruction* param : computation->parameter_instructions()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice, + assignment_.GetUniqueTopLevelSlice(param)); + computation_parameter_allocations_[param_slice.allocation()->index()] = + param->parameter_number(); + } + InitializeIrFunction(function_name); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -131,6 +145,8 @@ StatusOr IrEmitter::EmitComputation( // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. compute_function_.reset(); + computation_root_allocation_ = BufferAllocation::Slice(); + computation_parameter_allocations_.clear(); return ir_function; } @@ -175,25 +191,36 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { result_global, IrShapeType(literal.shape())->getPointerTo()); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { - VLOG(2) << "HandleConstant: " << constant->ToString(); - const Literal& literal = constant->literal(); - llvm::Constant* global_for_const; +Status IrEmitter::EmitConstantGlobals() { + for (const BufferAllocation& allocation : assignment_.Allocations()) { + if (!allocation.is_constant()) { + continue; + } - auto it = emitted_literals_.find(&literal); - if (it != emitted_literals_.end()) { - global_for_const = it->second; - } else { - global_for_const = EmitGlobalForLiteral(literal); - emitted_literals_[&literal] = global_for_const; + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + llvm::Constant* global_for_const; + auto it = emitted_literals_.find(&literal); + if (it != emitted_literals_.end()) { + global_for_const = it->second; + } else { + global_for_const = EmitGlobalForLiteral(literal); + InsertOrDie(&emitted_literals_, &literal, global_for_const); + } + + InsertOrDie(&constant_buffer_to_global_, allocation.index(), + global_for_const); } - emitted_value_[constant] = global_for_const; - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); - VLOG(2) << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); + return Status::OK(); } +Status IrEmitter::HandleConstant(HloInstruction* constant) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + // IrEmitter::EmitConstantGlobals has already taken care of emitting the body + // of the constant. + return EmitTargetAddressForOp(constant); +} + Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. @@ -472,23 +499,11 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -StatusOr IrEmitter::EmitTargetElementLoopBodyForMap( - HloMapInstruction* map, const llvm_ir::IrArray::Index& index) { - llvm::Function* mapped_ir_function = - FindOrDie(emitted_functions_, map->to_apply()); - std::vector parameter_addresses; - for (const HloInstruction* operand : map->operands()) { - const llvm_ir::IrArray& array = GetIrArrayFor(operand); - parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_)); - } - return EmitElementFunctionCall(mapped_ir_function, map->shape(), - parameter_addresses, "map_function"); -} - -Status IrEmitter::HandleMap(HloInstruction* map) { - return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) { - return EmitTargetElementLoopBodyForMap(Cast(map), index); - }); +llvm::Value* IrEmitter::EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name) { + return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( @@ -496,9 +511,6 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( const llvm_ir::IrArray::Index& index) { const HloInstruction* operand = reduce_window->operand(0); const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // We fold inputs into the accumulator and initialize it to // the initial value on the reduce_window. @@ -551,11 +563,10 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( // We are not in the padding, so carry out the computation. llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value_address = - input_array.EmitArrayElementAddress(input_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce_window->shape(), - {accumulator_address, input_value_address}, "reducer_function"); + llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce_window->to_apply(), + {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); b_.CreateStore(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); @@ -611,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { "Dilation for SelectAndScatter is not implemented on CPU. "); } - // The select and scatter computations should have been emitted previously. - llvm::Function* select_function = - FindOrDie(emitted_functions_, select_and_scatter->select()); - llvm::Function* scatter_function = - FindOrDie(emitted_functions_, select_and_scatter->scatter()); - // Pseudo code for select-and-scatter: // // initialized_flag is initially off for every window, and is turned on after @@ -721,11 +726,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); - const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - select_function, output_shape, {selected_value_address, operand_address}, + llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* result = EmitThreadLocalCall( + *select_and_scatter->select(), + {b_.CreateLoad(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the @@ -752,14 +758,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); - llvm::Value* source_value_address = - source_array.EmitArrayElementAddress(source_index, &b_); + llvm::Value* source_value = + source_array.EmitReadArrayElement(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); - llvm::Value* output_value_address = - output_array.EmitArrayElementAddress(selected_index, &b_); - llvm::Value* scatter_value = EmitElementFunctionCall( - scatter_function, source->shape(), - {output_value_address, source_value_address}, "scatter_function"); + llvm::Value* output_value = + output_array.EmitReadArrayElement(selected_index, &b_); + llvm::Value* scatter_value = + EmitThreadLocalCall(*select_and_scatter->scatter(), + {output_value, source_value}, "scatter_function"); output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); @@ -1236,46 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); - auto param_number = parameter->parameter_number(); - auto param_shape = parameter->shape(); - - // We have to access the parameter at offset param_number in the params - // array. The code generated here is equivalent to this C code: - // - // i8* param_address_untyped = params[param_number]; - // Param* param_address_typed = (Param*)param_address_untyped; - // - // Where Param is the actual element type of the underlying buffer (for - // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); - llvm::Value* param_address_offset = - llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset); - param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() - .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. - param_address_untyped->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); - } - - llvm::Value* param_address_typed = b_.CreateBitCast( - param_address_untyped, IrShapeType(param_shape)->getPointerTo()); - emitted_value_[parameter] = param_address_typed; - - if (!ShapeUtil::IsOpaque(param_shape)) { - AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); - } - - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); - return Status::OK(); + return EmitTargetAddressForOp(parameter); } // Returns true if the relative order of the unreduced dimensions stays the same @@ -1739,9 +1706,6 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); gtl::ArraySlice dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1781,10 +1745,9 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( CHECK(index.end() == it); // Apply the reduction function to the loaded value. - llvm::Value* input_address = - arg_array.EmitArrayElementAddress(input_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce->shape(), {accumulator_addr, input_address}, + llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, "reduce_function"); b_.CreateStore(result, accumulator_addr); @@ -1830,6 +1793,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) { return Unimplemented("Send-done is not implemented on CPU."); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on CPUs."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); @@ -2122,18 +2089,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); - std::vector parameter_addresses; - for (const HloInstruction* operand : call->operands()) { - parameter_addresses.push_back(GetEmittedValueFor(operand)); - } - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector call_args = GetArrayFunctionCallArguments( - parameter_addresses, &b_, computation->name(), + {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), /*temp_buffers_arg=*/GetTempBuffersArgument(), @@ -2144,8 +2106,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { call_args, root->shape(), root->outer_dimension_partitions(), &b_, call_ir_function, computation->name())); } else { - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - emitted_value_[call], computation->name()); + EmitGlobalCall(*computation, computation->name()); } return Status::OK(); @@ -2226,12 +2187,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); - // The called computation should have been emitted previously. - llvm::Function* condition_ir_function = - FindOrDie(emitted_functions_, condition); - llvm::Function* body_ir_function = - FindOrDie(emitted_functions_, xla_while->while_body()); - // Generating: // while (Condition(while_result)) { // // CopyInsertion pass inserts copies which enable 'while_result' to @@ -2248,12 +2203,10 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. - llvm::Value* while_result = GetEmittedValueFor(xla_while); - llvm::Value* while_condition = EmitElementFunctionCall( - condition_ir_function, condition->root_instruction()->shape(), - {while_result}, IrName(xla_while, "cond")); + EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); llvm::Value* while_predicate = b_.CreateICmpNE( - while_condition, + b_.CreateLoad( + GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2268,8 +2221,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { b_.SetInsertPoint(body_bb); // Calls the body function. - EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, - IrName(xla_while, "body")); + EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); + // Finishes with a branch back to the header. b_.CreateBr(header_bb); @@ -2437,8 +2390,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto pred = conditional->operand(0); - auto true_arg = conditional->operand(1); - auto false_arg = conditional->operand(2); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && pred->shape().element_type() == PRED) << "Predicate on a Conditional must be bool; got: " @@ -2460,13 +2411,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { << " and " << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); - llvm::Function* true_function = - FindOrDie(emitted_functions_, true_computation); - llvm::Function* false_function = - FindOrDie(emitted_functions_, false_computation); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); - llvm::Value* conditional_result = GetEmittedValueFor(conditional); // Generating: // if (pred) @@ -2483,12 +2428,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); SetToFirstInsertPoint(if_data.true_block, &b_); - EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, - conditional_result, IrName(conditional, "_true")); + EmitGlobalCall(*conditional->true_computation(), + IrName(conditional, "_true")); SetToFirstInsertPoint(if_data.false_block, &b_); - EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, - conditional_result, IrName(conditional, "_false")); + EmitGlobalCall(*conditional->false_computation(), + IrName(conditional, "_false")); SetToFirstInsertPoint(if_data.after_block, &b_); return Status::OK(); @@ -2506,6 +2451,23 @@ Status IrEmitter::HandleIota(HloInstruction* iota) { return Unimplemented("Iota is not implemented on CPU."); } +Status IrEmitter::HandleRng(HloInstruction* rng) { + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : rng->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_); + }; + } + + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + rng, elemental_emitter.MakeElementGenerator(rng, operand_to_generator))); + + llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); + + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2672,40 +2634,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { - llvm::Type* element_type = IrShapeType(target_shape); - // The alignment and number of bytes within the temporary buffer is determined - // by the maximal shape as determined by buffer assignment. - const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); - if (allocation.is_thread_local()) { + const BufferAllocation& allocation = *slice.allocation(); + llvm::Value* tempbuf_address = [&]() -> llvm::Value* { + if (slice == computation_root_allocation_) { + llvm::Argument* retval = compute_function_->result_arg(); + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttrs(attr_builder); + return retval; + } + + auto param_it = + computation_parameter_allocations_.find(slice.allocation()->index()); + if (param_it != computation_parameter_allocations_.end()) { + int64 param_number = param_it->second; + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); + llvm::LoadInst* param_address_untyped = + b_.CreateLoad(param_address_offset); + + if (!ShapeUtil::IsOpaque(target_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, + target_shape); + } + return param_address_untyped; + } + // Thread-local allocations should only be assigned a single buffer. const auto& assigned_buffers = allocation.assigned_buffers(); CHECK_EQ(1, assigned_buffers.size()); const Shape& shape = assigned_buffers.begin()->first->shape(); - llvm::AllocaInst*& tempbuf_address = - thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}]; - if (tempbuf_address == nullptr) { - tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( + std::pair key = { + compute_function_->function(), slice}; + auto buf_it = thread_local_buffers_.find(key); + if (buf_it == thread_local_buffers_.end()) { + llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, MinimumAlignmentForShape(target_shape)); + auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); + CHECK(it_inserted_pair.second); + buf_it = it_inserted_pair.first; } - return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo()); - } + return buf_it->second; + }(); + return b_.CreateBitCast(tempbuf_address, + IrShapeType(target_shape)->getPointerTo()); +} +llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( GetTempBuffersArgument(), slice.index(), &b_); llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() + if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2720,85 +2718,25 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } return b_.CreateBitCast(tempbuf_address_untyped, - element_type->getPointerTo()); -} - -// Emits a function call returning a single array element. Allocates space -// for a single element_type value, and loads it after call. -llvm::Value* IrEmitter::EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* return_value_buffer = EmitArrayFunctionCall( - function, return_shape, 1, parameter_addresses, name); - return b_.CreateLoad( - return_value_buffer, - AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); -} - -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - b_.CreateCall(function, - GetArrayFunctionCallArguments( - parameter_addresses, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); + IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* elements = - llvm::ConstantInt::get(b_.getInt64Ty(), element_count); - PrimitiveType return_type = return_shape.element_type(); - llvm::Value* return_value_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, - tensorflow::strings::StrCat(name, "_return_value_address"), &b_, - MinimumAlignmentForPrimitiveType(return_type)); - EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, - name); - return return_value_buffer; +llvm::Value* IrEmitter::EmitTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + if (slice.allocation()->is_thread_local()) { + return EmitThreadLocalTempBufferPointer(slice, target_shape); + } else if (slice.allocation()->is_constant()) { + return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + } else { + return EmitGlobalTempBufferPointer(slice, target_shape); + } } Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { - llvm::Value* addr; const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { - // For the root node, we write directly to the output buffer of the - // function. - llvm::Argument* retval = compute_function_->result_arg(); - if ((ShapeUtil::IsArray(target_shape) && - !ShapeUtil::IsZeroElementArray(target_shape)) || - (ShapeUtil::IsTuple(target_shape) && - !ShapeUtil::IsEmptyTuple(target_shape))) { - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - } - addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); - } else { - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - addr = EmitTempBufferPointer(slice, target_shape); - } + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2903,20 +2841,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -StatusOr IrEmitter::EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name) { - llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); - std::vector argument_addrs; - for (auto argument : arguments) { - llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( - argument->getType(), "arg_addr", &b_); - b_.CreateStore(argument, argument_addr); - argument_addrs.push_back(argument_addr); +llvm::Value* IrEmitter::EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, + tensorflow::StringPiece name) { + const Shape& return_shape = callee.root_instruction()->shape(); + + // Lifting this restriction to allow "small" arrays should be easy. Allowing + // larger arrays is difficult because we allocate the buffer for this return + // value on the stack. + CHECK(ShapeUtil::IsScalar(return_shape)); + + PrimitiveType return_type = return_shape.element_type(); + + std::vector parameter_addrs; + for (llvm::Value* parameter : parameters) { + CHECK(!parameter->getType()->isPointerTy()); + llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( + parameter->getType(), "arg_addr", &b_); + b_.CreateStore(parameter, parameter_addr); + parameter_addrs.push_back(parameter_addr); + } + + llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(return_type, module_), + tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + MinimumAlignmentForPrimitiveType(return_type)); + + b_.CreateCall( + FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); + + return b_.CreateLoad(return_value_buffer); +} + +void IrEmitter::EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name) { + b_.CreateCall(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); +} + +llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( + const HloComputation& callee) { + const HloInstruction* root_inst = callee.root_instruction(); + if (root_inst->opcode() == HloOpcode::kOutfeed) { + return llvm::Constant::getNullValue(b_.getInt8PtrTy()); } - return EmitElementFunctionCall(llvm_function, - ShapeUtil::MakeShape(return_type, {}), - argument_addrs, name); + + const BufferAllocation::Slice root_buffer = + assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); + return EmitTempBufferPointer(root_buffer, root_inst->shape()); } + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 4e928ffadc9ff0423ed970fd2099b7e8eb5a2aee..c9a1dab62dcbcd926baa82737d24efa03fd326e9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -100,10 +100,14 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } - // Emits a call to `computation` with scalar arguments `arguments`. - StatusOr EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name); + // Emit an LLVM global variable for every constant buffer allocation. + Status EmitConstantGlobals(); + + // Emit code to map one element according to `map_instr`. + llvm::Value* EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name); protected: // @@ -140,15 +144,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; - Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; Status HandleIota(HloInstruction* iota) override; + Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -214,9 +219,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emits code that computes the address of the given temporary buffer to the - // function. target_shape is the shape of this temporary buffer. - // The returned Value's type is a pointer to element_type. + // Helper for EmitTempBufferPointer. + llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); + + // Helper for EmitTempBufferPointer. + llvm::Value* EmitThreadLocalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape); + + // Emits code that computes the address of the given buffer allocation slice. + // + // TODO(sanjoy): This should be renamed to reflect that it no longer provides + // access to just temporaries. llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); @@ -228,44 +242,27 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::StringPiece function_name_suffix); // Used for LLVM IR register names. - // Methods that emit a function call. - // Parameters: - // function - The LLVM function to call. - // return_shape - The return shape of the HLO computation that was used to - // make the function. Not the same as the return type of the function - // in LLVM, since we use output parameters for the return type. - // element_count - number of elements to return (array form only). - // parameter_addresses - pointers to be passed to the function as - // parameters. - // name - used for LLVM IR register names. - - // Emits a function call, returning a scalar, often an element of a larger - // array. Returns a Value for the scalar element returned by the function. - llvm::Value* EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - tensorflow::gtl::ArraySlice parameter_addresses, + // Emits a call to a thread local function (e.g. to the computation nested + // within a reduce or a map). Thread local callees (by definition) only write + // to and read from thread local allocations. + // + // `parameters` holds the *scalar values* that need to be passed to the + // callee. The return value is the scalar returned by the callee. + llvm::Value* EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, tensorflow::StringPiece name); - // Array function call emitter. Stores the function's result into a supplied - // buffer. - // Parameters: - // function - The LLVM function to call. - // parameter_addresses - pointers to be passed to the function as - // parameters. - // return_value - pointer to a buffer where the call result is stored. - - void EmitArrayFunctionCallInto( - llvm::Function* function, - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Array function call emitter. Returns a Value for the function's return - // value buffer address. The return value buffer is alloca'ed by this - // function. - llvm::Value* EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - tensorflow::gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name); + // Emits a call to a "global" function (e.g. to the computation nested within + // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to + // the parameters and return values for these computations so there is no need + // to explicitly pass parameters or return results. + void EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name); + + // Returns the buffer to which a global call to `callee` would have written + // its result. + llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee); // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. @@ -404,11 +401,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { NameUniquer name_uniquer_; // Map containing all previously emitted computations. - std::map emitted_functions_; + std::map emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map, - llvm::AllocaInst*> + std::map, llvm::Value*> thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory @@ -418,6 +414,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::unique_ptr compute_function_; llvm::IRBuilder<> b_; + // The buffer allocation slice for the root of the computation being compiled. + // Only relevant for thread local computations. + BufferAllocation::Slice computation_root_allocation_; + + // Maps the buffer allocation slices for the parameters to the computation + // being compiled to their parameter numbers. Only relevant for thread local + // computations. + tensorflow::gtl::FlatMap + computation_parameter_allocations_; + // Maps HLO instructions to their index into the profile counter array. const std::unordered_map instruction_to_profile_idx_; @@ -559,6 +565,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { LiteralPtrHashFunctor, LiteralPtrEqualityFunctor> emitted_literals_; + tensorflow::gtl::FlatMap + constant_buffer_to_global_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 6aff838462ac6bfe8a31971108a721b66dbe45bd..2db4d000f5b149969c88fb4325ca28aa11dc3708 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -80,9 +80,16 @@ void IrFunction::Initialize(const string& function_name, // void function(i8* retval, i8* run_options, i8** params, i8** temps, // i64* dynamic_loop_bounds, i64* prof_counters) // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. + // For thread local functions: + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: is null + // + // For global functions: + // retval: is null + // params: is null + // temps: address of an array with pointers to temporary buffers and entry + // computation parameters. // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -196,18 +203,25 @@ std::vector GetArrayFunctionCallArguments( llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = - b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); - llvm::Value* slot_in_param_addresses = - b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); - b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + llvm::Value* parameter_addresses_buffer; + + if (parameter_addresses.empty()) { + parameter_addresses_buffer = + llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo()); + } else { + parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = + b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat( + name, "_parameter_", i, "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = + b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); + b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + } } const auto to_int8_ptr = [=](llvm::Value* ptr) { diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index ec0498e04e2d4c4c1c20cbcd724feb2bf3c7322d..cef5e57b0b12b7ae93af0d2508b2b9d6a592d390 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" @@ -54,44 +55,12 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, llvm::IRBuilder<> b(vector_tanh_body); llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); + fast_math_flags.setFast(enable_fast_math); b.setFastMathFlags(fast_math_flags); - VectorSupportLibrary vsl(F32, vector_width, &b, "tanh_f32"); - llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(input->getType(), vsl.vector_type()); - - // This implements the same rational interpolant as implemented in Eigen3. - llvm::Value* input_clamped = - vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0)); - - std::array numerator_coeffs{ - -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}; - - std::array denominator_coeffs{ - 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}; - - llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); - llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0])); - for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = - vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i])); - } - - numerator = vsl.Mul(input_clamped, numerator); - - llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0])); - for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = vsl.MulAdd(input_squared, denominator, - GetIeeeF32(denominator_coeffs[i])); - } - - llvm::Value* result = vsl.Div(numerator, denominator); - b.CreateRet(result); + CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); + b.CreateRet(llvm_ir::EmitFastTanh(&b, input)); DCHECK(!llvm::verifyFunction(*vector_tanh_function)); return vector_tanh_function; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index d03da46575b331de113cc5f33c2b4267504e8308..a5f34908d70dd18ec017bdf9833c7df40f80db07 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -58,13 +59,14 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // [partition1_dim2_start] // [partition1_dim2_limit] // -void __xla_cpu_runtime_ParallelForkJoin( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, void** temps, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); const xla::ExecutableRunOptions* run_options = @@ -79,9 +81,9 @@ void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, temps, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, params, temps, + function(result_ptr, run_options_ptr, nullptr, temps, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 39b13183ff093611a42b3931d45f64eadb420622..a71a85913cfef271bc2a226cb0cf2dd4204499a4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -77,27 +78,24 @@ void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } // namespace -void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr, - Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index f8c8dd5e93d53db8d87be0208b5cf4daac3464f1..997fdd2ab309f0b68a9dbd0f156a8dc19955b437 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -23,6 +23,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" +#include "tensorflow/core/platform/dynamic_annotations.h" using tensorflow::int32; using tensorflow::int64; @@ -74,10 +75,9 @@ void MatMulF64(const void* run_options_ptr, double* out, double* lhs, } // namespace -void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -88,11 +88,11 @@ void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } + // BLAS GEMM API for 64-bit Matrix Multiplication -void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -103,22 +103,26 @@ void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, float* rhs, + int64 m, int64 n, int64 k, + int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 17303e2f0d34e531a3a56aa147608b949e0f43ae..16692e7f2e6145b2649b67987eef47916e958be2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -17,6 +17,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -71,7 +72,8 @@ void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, } // namespace -void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { @@ -79,16 +81,22 @@ void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, + float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( - const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index d9e8dcaed98dc8bc33bf1355b0acddb56faa3c71..f227e4ae139b92e56786e38ef8eef72c9e2cd424 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index b4c33e2f6cad8455e1e4ff2f020a167121e80a6f..181cec3cdddeb40daf5276d9d1d6a139417a6072 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -135,9 +135,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index 0d45918d0990c0f00c08f321dad015cfbc038bf3..c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -220,7 +220,7 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // The body adds the reduced value of the Infeed data (first tuple element) // to the previous accumulator, and returns the accumulator and the continue // flag (second tuple element) as a tuple. - const auto build_body = [this, &result_shape](const Shape& infeed_shape) { + const auto build_body = [&result_shape](const Shape& infeed_shape) { XlaComputation body; XlaBuilder builder("body"); auto prev = Parameter(&builder, 0, result_shape, "prev"); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 097fa23027bf55ad0b92c347c5a1209bb5836695..86d57581f84920e8005e8f3c420e7488fc095434 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -106,6 +106,7 @@ class DfsHloVisitorBase { virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -233,6 +234,7 @@ class DfsHloVisitorBase { virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; virtual Status HandleGather(HloInstructionPtr hlo) = 0; + virtual Status HandleScatter(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index f4316e0fb77855aad1c4710908df09c604da896e..617a5a2eb4796d8003099e39e3d26389e532e954 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } + Status HandleAllToAll(HloInstructionPtr crs) override { + return DefaultAction(crs); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } @@ -194,6 +197,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleScatter(HloInstructionPtr scatter) override { + return DefaultAction(scatter); + } Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index b58b87a978a07c0b0adffd5ecb04b53de8ab9e7b..6aab317ca5b89fec1a01f92295a41c9fc26ccee1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1223,168 +1223,265 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( return source_index; } -llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( +StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { - PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type(); - llvm::Type* param_ir_type = - llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_); - - // Same values as PCG library - // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h - llvm::Value* multiplier = - b_->getInt(llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4})); - llvm::Value* increment = - b_->getInt(llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D})); - - auto random_value_from_hlo = [hlo]() { - const HloModule* module = - hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent() - : hlo->parent()->parent(); - return module->RandomNew64(); - }; + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, + operand_to_generator.at(hlo->operand(0))(index)); + TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, + operand_to_generator.at(hlo->operand(1))(index)); + PrimitiveType elem_prim_ty = hlo->shape().element_type(); + llvm::Type* elem_ir_ty = + llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); + llvm::Type* raw_value_ty = raw_value->getType(); + + // Convert raw integer to float in range [0, 1) if the element is a float. + llvm::Value* elem_value = raw_value; + if (elem_ir_ty->isFloatingPointTy()) { + unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); + // Perform the division using the float type with the same number of bits + // as the raw value to avoid overflow. + if (raw_value_size_in_bits == 32) { + elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + } else { + elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); + } - // Seed each RNG emitter with a new 64-bit seed from the HloModule. If the - // compilation order is deterministic (i.e., RandomNew64 invocation order is - // deterministic), then the order of RNG is deterministic for a given seed and - // hence tests will be deterministic. - // If the user provides a global seed instruction then we only use 64-bits of - // the host's random number generator to seed the 128 bit value with the other - // 64-bits is due to a user specified global seed instruction. - // Create a GlobalVariable to maintain state between invocations. There is a - // bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit + if (elem_ir_ty != elem_value->getType()) { + elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + } + } + + // Convert the value for the requested distribution. + switch (hlo->random_distribution()) { + case RNG_UNIFORM: { + if (elem_ir_ty->isFloatingPointTy()) { + return b_->CreateFAdd( + b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), + a_or_mean); + } else { + // To generate a uniform random value in [a, b) from a raw random sample + // in range [0, 2^N), we let range = b - a and return + // (a + raw_value % range). If range is not a power of 2, raw values + // larger than (2^N - 2^N % range) are biased toward results in + // [a, a + (limit % range)). An unbiased algorithm would need to drop + // raw values and re-sample, but we don't do this because re-sampling in + // an efficient way is complex, and it's not clear that users need it. + // In particular, if one thread in a GPU warp needs to re-sample, we pay + // the same cost as if the whole warp were to re-sample. So an + // efficient re-sampling implementation on GPU would need to do + // nontrivial work to share entropy between threads in the warp. + auto range = b_->CreateSub(b_or_sigma, a_or_mean); + return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + } + } + case RNG_NORMAL: { + TF_ASSIGN_OR_RETURN( + llvm::Value * r, + EmitErfcInv(elem_prim_ty, + b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + } + default: + return InvalidArgument( + "unhandled distribution %s", + RandomDistribution_Name(hlo->random_distribution()).c_str()); + } +} + +namespace { + +// Checks that the primitive type is supported by the elemental IR emitter for +// Philox RNG and returns the number of elements in each 128 bit sample of the +// Philox RNG algorithm. +int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) { + // Calculate the number of elements, that is the number of random numbers, in + // a 128 bit sample. + switch (elem_prim_ty) { + case U32: + case S32: + case F32: + // The algorithm uses 32 bits to generate values for F16. + case F16: + return 4; + case U64: + case S64: + case F64: + return 2; + default: + // BF16 is converted to F16 by the hlo pass HloElementTypeConverter. + // Other data types are not supported by XLA random operation. + LOG(FATAL) << "Unrecognized primitive type for RNG " << elem_prim_ty; + } + return 0; +} + +// Calculates the four uint32 values for the 128-bit Philox sample. +std::array CalculateSampleValues( + llvm::Value* sample_idx, llvm::Value* hlo_random_value, + llvm::Value* global_random_number, llvm::Value* rng_state, + llvm::IRBuilder<>* b) { + llvm::Type* index_ty = sample_idx->getType(); + + std::array counter_values; + + // Use the sample index to initialize counter[0] and counter[1]. + unsigned index_ty_size_in_bits = index_ty->getPrimitiveSizeInBits(); + CHECK(index_ty_size_in_bits == 32 || index_ty_size_in_bits == 64); + if (index_ty_size_in_bits == 32) { + counter_values[0] = sample_idx; + counter_values[1] = b->getInt32(0); + } else { + std::tie(counter_values[0], counter_values[1]) = + llvm_ir::SplitInt64ToInt32s(b, sample_idx); + } + + // Xor the global state variable with the global random number seed and use + // the result to initialize counter[2] and counter[3]. + std::tie(counter_values[2], counter_values[3]) = llvm_ir::SplitInt64ToInt32s( + b, b->CreateXor(rng_state, global_random_number)); + + // The algorithm uses a 64 bit key, which is also interpreted as two uint32 // values. - llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable( - /*M=*/*module_, - /*Ty=*/b_->getInt64Ty(), - /*isConstant=*/false, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/b_->getInt64(random_value_from_hlo()), - /*Name=*/"state_ptr0"); - - // When the module config seed is 0, the expected result of a prng is a random - // value. Instead of using the random_value_from_hlo, we need a global random - // value as the graph seed. This is because if we use random_value_from_hlo - // here, then for a newly built hlo graph, it always gives the same number. - uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() - : GlobalRandomValue(); - llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable( - /*M=*/*module_, - /*Ty=*/b_->getInt64Ty(), - /*isConstant=*/false, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/b_->getInt64(graph_seed), - /*Name=*/"state_ptr1"); - - // We want each thread to use its own stream, so we modify the increment per - // thread. We want the increment to remain odd, so we shift the thread id left - // 1 and add it to the increment. - increment = b_->CreateAdd(increment, b_->CreateShl(EmitThreadId(), 1)); - - // PCG-XSL-RR algorithm - // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf - // state = multiplier * state + increment - // return uint64_t(state ^ (state >> 64))) >>> (state >> 122) - // where ">>>" is bitwise rotation - auto get_next_i64 = [=]() { - llvm::Value* state0 = b_->CreateZExtOrTrunc( - b_->CreateLoad(state_ptr0, "state0"), b_->getInt128Ty()); - llvm::Value* state1 = b_->CreateShl( - b_->CreateZExtOrTrunc(b_->CreateLoad(state_ptr1, "state1"), - b_->getInt128Ty()), - 64); - llvm::Value* state = b_->CreateOr(state0, state1); - llvm::Value* updated = - b_->CreateAdd(b_->CreateMul(state, multiplier), increment); - b_->CreateStore(b_->CreateTrunc(updated, b_->getInt64Ty()), state_ptr0); - b_->CreateStore( - b_->CreateTrunc(b_->CreateLShr(updated, 64), b_->getInt64Ty()), - state_ptr1); - - return llvm_ir::CreateRor( - b_->CreateTrunc(b_->CreateXor(state, b_->CreateLShr(state, 64)), - b_->getInt64Ty()), - b_->CreateTrunc(b_->CreateLShr(state, 122), b_->getInt64Ty()), b_); - }; + llvm::Value* key_values[2]; + + // Use a module random number to initialize the key. + std::tie(key_values[0], key_values[1]) = + llvm_ir::SplitInt64ToInt32s(b, hlo_random_value); + + // Prepare the constants used in the Philox RNG Algorithm. + llvm::Value* philoxW32A = b->getInt32(0x9E3779B9); + llvm::Value* philoxW32B = b->getInt32(0xBB67AE85); + llvm::Value* philoxM4xW32A = b->getInt32(0xD2511F53); + llvm::Value* philoxM4xW32B = b->getInt32(0xCD9E8D57); + + // Compute the 128 bit value for the current sample by repeating the + // single round computation and key raising computation for ten times. + for (int round = 0; round < 10; ++round) { + // A single round of computation of the counter values is as follows: + // MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0); + // MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1); + // counter[0] = hi1 ^ counter[1] ^ key[0]; + // counter[1] = lo1; + // counter[2] = hi0 ^ counter[3] ^ key[1]; + // counter[3] = lo0; + llvm::Value* lo0; + llvm::Value* hi0; + std::tie(lo0, hi0) = + llvm_ir::UMulLowHigh32(b, philoxM4xW32A, counter_values[0]); + llvm::Value* lo1; + llvm::Value* hi1; + std::tie(lo1, hi1) = + llvm_ir::UMulLowHigh32(b, philoxM4xW32B, counter_values[2]); + counter_values[0] = + b->CreateXor(hi1, b->CreateXor(counter_values[1], key_values[0])); + counter_values[1] = lo1; + counter_values[2] = + b->CreateXor(hi0, b->CreateXor(counter_values[3], key_values[1])); + counter_values[3] = lo0; + key_values[0] = b->CreateAdd(key_values[0], philoxW32A); + key_values[1] = b->CreateAdd(key_values[1], philoxW32B); + } - auto get_next_uniform_float = [=]() { - return b_->CreateFDiv(b_->CreateUIToFP(get_next_i64(), param_ir_type), - llvm::ConstantFP::get(param_ir_type, 0x1p64)); - }; + return counter_values; +} +} // namespace + +// Implements the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +// +// The paper presents a few variants of the Philox algorithm, we picked the +// 4x32_10 version of the algorithm for the following reasons: +// . 4x32 uses 32-bit multiplication which is fast on GPUs. +// . The authors recommend the 10-round variant, and TensorFlow also uses it. +// +// Precondition: the RNG instruction is not fused. +llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) + const { + VLOG(3) << "Using philox RNG algorithm"; + CHECK(!hlo->IsFused()); + // A random number generated by the per module random number generator. + // This ensures that each RNG HLO generates a different random sequence. + llvm::Value* hlo_random_value = b_->getInt64(hlo->GetModule()->RandomNew64()); + // A value specified by the configuration or generated by a global random + // number generator. + llvm::Value* global_random_number = + b_->getInt64(hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() + : GlobalRandomValue()); + + int elems_per_sample = + GetNumberOfElementsPerPhiloxRngSample(hlo->shape().element_type()); + + // Allocate stack storage for the 128 bit sample as four int32. + llvm::Type* int32_ty = b_->getInt32Ty(); + llvm::Value* sample_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + int32_ty, /*element_count=*/b_->getInt32(4), "sample", b_); + + // Load the global state variable for the Philox RNG algorithm. + llvm::GlobalVariable* rng_state_ptr = + llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); + llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + + // Build and return the elemental IR generator to generate a random value for + // the element corresponding to the current thread. + // + // This elemental IR generator computes one sample with multiple random + // numbers but only returns one random number. As a result, neighboring + // threads may calculate the same sample unnecessarily. However, if the + // kernel containing the RNG hlo is unrolled, LLVM is able to optimize away + // the duplicated computation of the same sample. In particular, if the unroll + // factor is a multiplier of elems_per_sample, LLVM is able to completely + // remove such duplicated computation. If the unroll factor is a non-trivial + // factor of elems_per_sample, LLVM can only partially remove such duplicated + // computation. return [=](const llvm_ir::IrArray::Index& index) -> StatusOr { - switch (hlo->random_distribution()) { - case RNG_UNIFORM: { - TF_ASSIGN_OR_RETURN(llvm::Value * p, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * q, - operand_to_generator.at(hlo->operand(1))(index)); - if (primitive_util::IsFloatingPointType(param_prim_type)) { - return b_->CreateFAdd( - b_->CreateFMul(b_->CreateFSub(q, p), get_next_uniform_float()), - p); - } else { - auto r = b_->CreateSub(q, p); - auto leading_zeros = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ctlz, {r, b_->getInt1(true)}, {param_ir_type}, - b_); - auto in_block = b_->GetInsertBlock(); - - // A terminator should be present iff we're emitting code - // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(b_->GetInsertPoint() == in_block->end(), - in_block->getTerminator() == nullptr); - - llvm::BasicBlock* body_block; - llvm::BasicBlock* out_block; - - if (b_->GetInsertPoint() == in_block->end()) { - body_block = - llvm_ir::CreateBasicBlock(nullptr, IrName(hlo, "rng_body"), b_); - out_block = - llvm_ir::CreateBasicBlock(nullptr, IrName(hlo, "rng_out"), b_); - llvm::BranchInst::Create(body_block, in_block); - } else { - body_block = - in_block->splitBasicBlock(b_->GetInsertPoint(), "rng_body"); - out_block = - body_block->splitBasicBlock(b_->GetInsertPoint(), "rng_out"); - body_block->getTerminator()->eraseFromParent(); - } - - SetToFirstInsertPoint(body_block, b_); - auto random = b_->CreateAnd( - b_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), - b_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), - leading_zeros)); - llvm::BranchInst::Create(out_block, body_block, - b_->CreateICmpULT(random, r), body_block); - SetToFirstInsertPoint(out_block, b_); - return b_->CreateAdd( - p, b_->CreateSelect(b_->CreateICmpEQ(p, q), - llvm::ConstantInt::get(param_ir_type, 0), - random)); - } - } - case RNG_NORMAL: { - TF_ASSIGN_OR_RETURN(llvm::Value * m, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * s, - operand_to_generator.at(hlo->operand(1))(index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * r, - EmitErfcInv( - param_prim_type, - b_->CreateFMul(llvm::ConstantFP::get(param_ir_type, 2.0), - get_next_uniform_float()))); - return b_->CreateFAdd(b_->CreateFMul(r, s), m); - } - default: - return InvalidArgument( - "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + llvm::Type* index_ty = index.GetType(); + // Calculate the linear element index. + llvm::Value* elem_idx = index.linear(); + if (elem_idx == nullptr) { + elem_idx = index.Linearize(AsInt64Slice(hlo->shape().dimensions()), b_); } + + // Calculate the index for the 128 bit sample and the offset of the current + // element within the sample. + llvm::Value* elems_per_sample_value = + llvm::ConstantInt::get(index_ty, elems_per_sample); + llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + + std::array counter_values = CalculateSampleValues( + sample_idx, hlo_random_value, global_random_number, rng_state, b_); + + // Store the four counter_values into the sample_address alloca so we can + // load the elem_offset'th one below. + for (int idx = 0; idx < 4; ++idx) { + b_->CreateStore(counter_values[idx], + b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); + } + + llvm::Type* int64_ty = b_->getInt64Ty(); + CHECK(elems_per_sample == 2 || elems_per_sample == 4); + llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; + // Retrieve the raw value for the current element from the current sample. + llvm::Value* raw_elem_value = b_->CreateLoad( + b_->CreateInBoundsGEP( + b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), + "raw_elem_value"); + + return ConvertValueForDistribution(hlo, operand_to_generator, index, + raw_elem_value); }; } @@ -1517,10 +1614,6 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to officially document different behavior. start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); int64 largest_valid_start_index = input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); @@ -1671,10 +1764,6 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to officially document different behavior. start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); llvm::Value* update_dim_size = index_typed_const(update_hlo->shape().dimensions(i)); @@ -2042,7 +2131,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_)); }; case HloOpcode::kRng: - return MakeRngElementGenerator(hlo, operand_to_generator); + return MakePhiloxRngElementGenerator(hlo, operand_to_generator); case HloOpcode::kPad: return [this, hlo, &operand_to_generator]( const IrArray::Index& padded_index) -> StatusOr { @@ -2056,7 +2145,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", HloOpcodeString(hlo->opcode()).c_str()); }; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index deba6bea0a9944183bb05d1f12807845ddcc6260..fcb34557a52d35ef30a5dee643171e17407d05c2 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -193,10 +193,17 @@ class ElementalIrEmitter { const HloModuleConfig& hlo_module_config_; private: - // Returns a ElementGenerator for a RNG HloInstruction. - llvm_ir::ElementGenerator MakeRngElementGenerator( + // Returns a ElementGenerator for an RNG HloInstruction using the Philox + // random number generation algorithm. + llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const; + // Converts the raw value generated by a random number generation algorithm + // to the distribution requested by the RNG HloInstruction. + StatusOr ConvertValueForDistribution( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 6794cfe297b0fb9a15eb9b7e6906d225f9597d07..228c3fac95c3114484637bd93ec51c60b44403cc 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -25,7 +25,7 @@ limitations under the License. namespace xla { AsyncExecution::AsyncExecution(Backend* backend, - std::vector streams, + std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result) : backend_(CHECK_NOTNULL(backend)), @@ -46,9 +46,10 @@ Status AsyncExecution::BlockUntilDone() const { ExecutionTracker::ExecutionTracker() : next_handle_(1) {} -ExecutionHandle ExecutionTracker::Register( - Backend* backend, std::vector streams, - const ExecutionProfile& profile, GlobalDataHandle result) { +ExecutionHandle ExecutionTracker::Register(Backend* backend, + std::vector streams, + const ExecutionProfile& profile, + GlobalDataHandle result) { tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 4458152dd9a98890fc3a3e7f324245ec68821467..4e9b9f883e26f5564a9c63a40d2b4b9348908214 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,7 +40,7 @@ namespace xla { // the stream when destructed. class AsyncExecution { public: - AsyncExecution(Backend* backend, std::vector streams, + AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); Status BlockUntilDone() const; @@ -54,7 +54,7 @@ class AsyncExecution { Backend* backend_; // Stream on which the execution is launched. - std::vector streams_; + std::vector streams_; // Profile object of the execution to be returned to the user. ExecutionProfile profile_; @@ -72,7 +72,7 @@ class ExecutionTracker { // Registers an execution with its backend, streams, and data handle to the // execution result. Returns a handle for the registered execution. ExecutionHandle Register(Backend* backend, - std::vector stream, + std::vector stream, const ExecutionProfile& profile, GlobalDataHandle data); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index e314a469f00abdb9f60ae812c0b78d273dc95dbe..0ce2db907b643f3beabd127388370dbe601179e1 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( void GenericTransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) { + MutableBorrowingLiteral literal, std::function done) { Status status = stream->BlockHostUntilDone(); if (!status.ok()) { return done(status); } - done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); + + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer, + literal)); } -StatusOr> -GenericTransferManager::TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { +Status GenericTransferManager::TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), device_buffer.on_host_shape())); - std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.on_host_shape()); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { @@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->untyped_data(index))); + literal.untyped_data(index))); } return Status::OK(); })); - return std::move(literal); + return Status::OK(); } Status GenericTransferManager::TransferLiteralToDeviceAsync( @@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed( Status GenericTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { return Unimplemented("Generic transfer from Outfeed"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3cd002c1bf3555cc2d2891c88b3ad648f8d9fd8c..6c1a21587a7ef5199afb93715dc57be5139fbc22 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) override; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) override; Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, @@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; Status ResetDevices( tensorflow::gtl::ArraySlice executors) override; @@ -67,8 +67,9 @@ class GenericTransferManager : public TransferManager { const Shape& shape, se::DeviceMemoryBase* region) override; private: - StatusOr> TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal); // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 72aff197fce52a8d0e20e6f1dcd8cc7a0a35af45..a3f6e8d9893528642e05354994c1d826949c6063 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -114,11 +114,13 @@ cc_library( srcs = ["hlo_to_ir_bindings.cc"], hdrs = ["hlo_to_ir_bindings.h"], deps = [ + ":buffer_allocations", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", @@ -142,6 +144,7 @@ cc_library( ], deps = [ ":backend_configs", + ":buffer_allocations", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -150,7 +153,6 @@ cc_library( ":ir_emission_utils", ":parallel_loop_emitter", ":partition_assignment", - ":while_transformer", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -163,6 +165,8 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:while_loop_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", @@ -217,6 +221,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -247,7 +252,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:pool", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", @@ -322,6 +327,7 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -451,6 +457,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":instruction_fusion", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", @@ -463,6 +470,7 @@ tf_cc_test( name = "multi_output_fusion_test", srcs = ["multi_output_fusion_test.cc"], deps = [ + ":instruction_fusion", ":multi_output_fusion", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -536,6 +544,38 @@ cc_library( ], ) +cc_library( + name = "pad_for_tensor_cores", + srcs = ["pad_for_tensor_cores.cc"], + hdrs = ["pad_for_tensor_cores.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:shape_inference", + ], +) + +tf_cc_test( + name = "pad_for_tensor_cores_test", + srcs = ["pad_for_tensor_cores_test.cc"], + deps = [ + ":ir_emission_utils", + ":pad_for_tensor_cores", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + ], +) + cc_library( name = "gpu_transfer_manager", srcs = ["gpu_transfer_manager.cc"], @@ -580,9 +620,11 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":multi_output_fusion", + ":pad_for_tensor_cores", ":pad_insertion", ":partition_assignment", ":stream_assignment", + ":stream_executor_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -594,7 +636,6 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -614,7 +655,6 @@ cc_library( "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", - "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -707,6 +747,8 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], @@ -745,32 +787,17 @@ tf_cc_test( ], ) -cc_library( - name = "while_transformer", - srcs = ["while_transformer.cc"], - hdrs = ["while_transformer.h"], - deps = [ - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - ], -) - tf_cc_test( name = "while_transformer_test", srcs = ["while_transformer_test.cc"], deps = [ ":instruction_fusion", - ":while_transformer", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -806,6 +833,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index b095d4cd731bb7877baffbf69cb17bd50e101d6b..537295292b6ced72c4b2c456557b3c06e0aa5254 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -44,12 +44,22 @@ StatusOr> BufferAllocations::Builder::Build( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); + const int64 expected_alignment = [&] { + if (allocation.is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (allocation.is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. if (registered_buffers_.count(i)) { se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % - kEntryParameterAlignBytes != + if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %lld must be a multiple of %llx, but " @@ -62,7 +72,6 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; @@ -70,8 +79,7 @@ StatusOr> BufferAllocations::Builder::Build( OwningDeviceMemory buffer; TF_ASSIGN_OR_RETURN( buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); - if (reinterpret_cast(buffer.opaque()) % - kXlaAllocatedBufferAlignBytes != + if (reinterpret_cast(buffer.opaque()) % expected_alignment != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " @@ -165,5 +173,10 @@ void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index, buffers_[buffer_index] = buffer; } +bool ShouldEmitLiteralInLlvmIr(const Literal& literal) { + // LLVM can sometimes do interesting optimizations using scalar constants. + return ShapeUtil::IsScalar(literal.shape()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 636623502597b3a66523938ba430e9d5a82f796c..f13eab0dd787a2bfa687c991f9d808568360fd24 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -107,6 +107,12 @@ class BufferAllocations { bool torn_down_ = false; }; +// LLVM and PTXAS don't deal well with large constants, so we only emit very +// small constants directly in LLVM IR. Larger constants are emitted with zero +// initializers in LLVM IR and are later overwritten when the PTX/CUBIN is +// loaded. +bool ShouldEmitLiteralInLlvmIr(const Literal& literal); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5a63e65208ac3e8e23944bc31634f4d29d91c10c..7348307ec8a7286dfb733d6b9685862b20f11ac9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" namespace xla { namespace gpu { @@ -137,6 +138,28 @@ string NumBytesToString(int64 bytes) { tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); } +// Acquires a process-global lock on the device pointed to by the given +// StreamExecutor. +// +// This is used to prevent other XLA instances from trying to autotune on this +// device while we're using it. +tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + // se::Platform*s are global singletons guaranteed to live forever. + static auto* mutexes = + new std::map, + tensorflow::mutex>(); + + tensorflow::mutex_lock global_lock(mu); + auto it = mutexes + ->emplace(std::piecewise_construct, + std::make_tuple(stream_exec->platform(), + stream_exec->device_ordinal()), + std::make_tuple()) + .first; + return tensorflow::mutex_lock{it->second}; +} + } // anonymous namespace // We could have caching here so that we don't redo this work for two identical @@ -155,6 +178,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + // Don't run this function concurrently on the same GPU. + // + // This is a bit of a hack and doesn't protect us against arbitrary concurrent + // use of a GPU, but it's sufficient to let us compile two HLO modules + // concurrently and then run them sequentially. + tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index b97a627d9bb89d9409cc0f45e52ce78119d8b575..cc38db27e2680e950f74e104cef8829585c7b81c 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -277,6 +278,16 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { case HloOpcode::kTanh: + // If we don't care much about precision, emit a fast approximation of + // tanh. + if (hlo_module_config_.debug_options().xla_enable_fast_math()) { + // Upcast F16 to F32 if necessary. + llvm::Type* type = + input_type == F16 ? b_->getFloatTy() : operand_value->getType(); + llvm::Value* input = b_->CreateFPCast(operand_value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return b_->CreateFPCast(fast_tanh, operand_value->getType()); + } return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); default: diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index b3a3c5dcb4d77889b65a119f09ddef9ba95d6b52..2fd2206324e5f763490780a54880825a772b7ea2 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -43,6 +43,8 @@ Status ForThunk::Initialize(const GpuExecutable& executable, Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { + VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " + << (hlo_instruction() ? hlo_instruction()->ToString() : ""); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (int64 i = 0; i < loop_limit_; ++i) { profiler->StartHloComputation(); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index dbc7754e251eb8075ab97dd2f36bbc400530fcf5..74282c568c09921dbeec2e9cce79b6c73b6ea592 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -31,16 +32,19 @@ namespace { // dimensions. struct MatrixDescriptor { MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose, - int64 matrix_num_rows, int64 matrix_num_cols) + int64 matrix_num_rows, int64 matrix_num_cols, + int64 matrix_batch_size) : data(matrix_data), transpose(needs_transpose), num_rows(matrix_num_rows), - num_cols(matrix_num_cols) {} + num_cols(matrix_num_cols), + batch_size(matrix_batch_size) {} se::DeviceMemoryBase data; bool transpose; // Whether this matrix needs to be transposed. int64 num_rows; int64 num_cols; + int64 batch_size; }; // Performs a gemm call without an explicit algorithm on lhs_matrix and @@ -50,6 +54,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { DCHECK(!output_matrix.transpose); + const int64 batch_size = lhs_matrix.batch_size; + CHECK_EQ(batch_size, rhs_matrix.batch_size); + CHECK_EQ(batch_size, output_matrix.batch_size); se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -60,13 +67,30 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, : se::blas::Transpose::kNoTranspose; auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; + if (batch_size == 1) { + return stream + ->ThenBlasGemm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows) + .ok(); + } + + int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols; + int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; + int64 output_stride = output_matrix.num_rows * output_matrix.num_cols; return stream - ->ThenBlasGemm( + ->ThenBlasGemmStridedBatched( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, - &output_data, /*leading dim of output=*/output_matrix.num_rows) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/alpha, lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, + /*beta=*/0.0, &output_data, + /*leading dim of output=*/output_matrix.num_rows, output_stride, + batch_size) .ok(); } @@ -93,6 +117,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, se::blas::ProfileResult* output_profile_result) { DCHECK(!output_matrix.transpose); + CHECK_EQ(1, lhs_matrix.batch_size); + CHECK_EQ(1, rhs_matrix.batch_size); + CHECK_EQ(1, output_matrix.batch_size); + se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -141,9 +169,15 @@ StatusOr DoGemmAutotune( alpha, computation_type, algorithm, stream, &profile_result)); - if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; + if (profile_result.is_valid()) { + VLOG(3) << "cublas gemm algorithm " << algorithm << " took " + << profile_result.elapsed_time_in_ms() << "ms"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + } else { + VLOG(4) << "cublas gemm algorithm " << algorithm << " failed."; } } @@ -167,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case F64: return &DoGemm; + case C64: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -180,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case F64: return &DoGemmWithAlgorithm; + case C64: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -192,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case F64: return &DoGemmAutotune; + case C64: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -210,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF32; case F64: return se::blas::ComputationType::kF64; + case C64: + return se::blas::ComputationType::kComplexF32; default: LOG(FATAL) << "Unsupported type."; } @@ -263,12 +305,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::DeviceMemoryBase output_data = buffer_allocations.GetDeviceAddress(output_buffer_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(output_shape_)); + + int64 row_dim = dim_nums.lhs_batch_dimensions_size(); + int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; + int64 batch_size = std::accumulate(output_shape_.dimensions().begin(), + output_shape_.dimensions().end() - 2, 1, + std::multiplies()); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_NE(row_dim, batch_dim); + CHECK_NE(col_dim, batch_dim); + } + + // Verify that the non-batch dimensions are minor-most. This is required for + // efficient access. + for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) { + CHECK_LT(shape->layout().minor_to_major(row_dim), 2); + CHECK_LT(shape->layout().minor_to_major(col_dim), 2); + } + // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of // their layout. Therefore, we should treat dimension 0 as row and dimension 1 // as column when mapping a matrix Dot to BLAS gemm. - int64 output_num_rows = output_shape_.dimensions(0); - int64 output_num_cols = output_shape_.dimensions(1); + int64 output_num_rows = output_shape_.dimensions(row_dim); + int64 output_num_cols = output_shape_.dimensions(col_dim); // BLAS gemm expects the inputs and the output are in column-major order. // Therefore, we need to convert dot between row-major matrices to that @@ -291,34 +358,46 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // the leading dimension of the LHS matrix of gemm is the number of rows in // B^T and thus the number of columns in B. - auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, - bool transpose) -> MatrixDescriptor { - bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0; - bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) != - LayoutUtil::Minor(output_shape_.layout(), 0); - return MatrixDescriptor(data, transpose ^ layout_mismatch, - shape.dimensions(is_row_major), - shape.dimensions(!is_row_major)); + auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape, + bool transpose) -> MatrixDescriptor { + bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) != + LayoutUtil::Minor(output_shape_.layout(), row_dim); + return MatrixDescriptor( + data, transpose ^ layout_mismatch, + shape.dimensions(row_dim + static_cast(is_row_major)), + shape.dimensions(row_dim + static_cast(!is_row_major)), + batch_size); }; - DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); - const MatrixDescriptor lhs_descriptor = make_descriptor( - lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim); const MatrixDescriptor rhs_descriptor = make_descriptor( - rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. - auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { + auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { PrimitiveType element_type = output_shape_.element_type(); se::blas::ComputationType computation_type = GetBlasComputationType(element_type); + // TODO(b/112111608): Implement auto tune for batched gemm. + if (batch_size != 1) { + return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + alpha_, stream); + } + + auto thunk_name = [&] { + return hlo_instruction() != nullptr ? hlo_instruction()->ToString() + : ""; + }; + const string& device_name = stream->parent()->GetDeviceDescription().name(); auto autotune_it = autotune_results_.find(device_name); if (autotune_it == autotune_results_.end()) { + VLOG(3) << "Starting autotune of GemmThunk " << thunk_name(); StatusOr best_algorithm = GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, stream); @@ -326,11 +405,11 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, autotune_results_.insert({device_name, best_algorithm}).first; if (autotune_it->second.ok()) { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " successful; best algorithm is " << best_algorithm.ValueOrDie(); } else { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " unsuccessful. Will use generic gemm."; } } @@ -340,7 +419,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (best_algorithm.ok()) { auto algorithm = best_algorithm.ValueOrDie(); VLOG(2) << "Using algorithm " << algorithm - << " chosen by autotuning on GemmThunk " << this; + << " chosen by autotuning on GemmThunk " << thunk_name(); return GetGemmWithAlgorithmFn(element_type)( lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, algorithm, stream, @@ -355,16 +434,16 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); bool launch_ok; - if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { - launch_ok = launch( - lhs_descriptor, rhs_descriptor, - MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), - stream); + if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) { + launch_ok = launch(lhs_descriptor, rhs_descriptor, + MatrixDescriptor(output_data, false, output_num_rows, + output_num_cols, batch_size), + stream); } else { - launch_ok = launch( - rhs_descriptor, lhs_descriptor, - MatrixDescriptor(output_data, false, output_num_cols, output_num_rows), - stream); + launch_ok = launch(rhs_descriptor, lhs_descriptor, + MatrixDescriptor(output_data, false, output_num_cols, + output_num_rows, batch_size), + stream); } if (!launch_ok) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc index e6ddea6d2578bbb482c481a511cc8d8adb5fa2d6..7f0b030fece8f25578bd90a538279d455350278a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc @@ -30,5 +30,7 @@ const int64 kEntryParameterAlignBytes = 16; const int64 kXlaAllocatedBufferAlignBytes = tensorflow::Allocator::kAllocatorAlignment; +const int64 kConstantBufferAlignBytes = kXlaAllocatedBufferAlignBytes; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h index 925e6927b64625f011efe7b4b960421f41ddee79..6f5f1fa09c57dfd246d702c0adc92c7e2e76805a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -28,6 +28,9 @@ extern const int64 kEntryParameterAlignBytes; // out (result) buffers. extern const int64 kXlaAllocatedBufferAlignBytes; +// Minimum alignment for constant buffers. +extern const int64 kConstantBufferAlignBytes; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index fbc1303085b579e898d2f503a341754109768567..75f414e47fe3edcc1b10b392ed5cc5038be6c190 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -48,80 +48,17 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(*module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. Also, init values of while and conditional nodes cannot be - // constants. Insert copies for any constants found at the operands of these - // nodes. - tensorflow::gtl::FlatSet inserted_copies; + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. for (HloComputation* computation : module->computations()) { for (HloInstruction* hlo : computation->instructions()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - // Skip the operands that have already been replaced with a copy in a - // previous iteration (which is possible when a constant is used as an - // operand in multiple places). - if (ContainsKey(inserted_copies, operand)) { - return Status::OK(); - } - for (auto& pair : dataflow->GetInstructionValueSet(operand)) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->IsConstant() && - !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - inserted_copies.insert(copy); - changed = true; - } - } - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo) || - hlo->opcode() == HloOpcode::kCrossReplicaSum || - hlo->opcode() == HloOpcode::kWhile || - hlo->opcode() == HloOpcode::kConditional) { - // For all other library calls, cross-replica-sum, while and conditional - // ops materialize all the operands into memory. (Cross-replica-sum - // gets its constant args materialized even if it's not implemented as a - // libcall to simplify the implementation. It's slower, but we can - // constant fold away constant args *anyway*, so we just need to make it - // work.) - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } + if (!IsCustomCallToDnnBatchNorm(*hlo)) { + continue; } - } - } - - if (changed) { - // Check the assumption that the epsilon and feature_index constants of the - // CUDNN batchnorm op are not shared with other ops where we would replace - // them with a copy. These custom op calls are generated with the - // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. - for (HloComputation* computation : module->computations()) { - for (HloInstruction* hlo : computation->instructions()) { - if (!IsCustomCallToDnnBatchNorm(*hlo)) { - continue; - } - for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); - ++i) { - CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); - } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 0cad2958c72797b4d70f00676928b2b21d7a3e8d..bb7736efa65c49766ea88a43fdab2b7102100c11 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -84,7 +85,7 @@ Status GpuExecutable::ExecuteThunks( } // Stream 0 indicates `main_stream` and substreams start from stream 1. - std::vector::SmartPtr> sub_streams; + std::vector sub_streams; sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); @@ -181,6 +182,55 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } +StatusOr +GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { + tensorflow::mutex_lock lock(module_handle_mutex_); + auto it = module_globals_.find(executor); + if (it != module_globals_.end()) { + return &it->second; + } + + se::MultiModuleLoaderSpec module_spec; + if (!cubin().empty()) { + module_spec.AddCudaCubinInMemory(cubin()); + } + module_spec.AddCudaPtxInMemory(ptx().c_str()); + + tensorflow::gtl::FlatMap globals; + se::ModuleHandle module_handle; + executor->LoadModule(module_spec, &module_handle); + + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_constant()) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase global, + executor->GetUntypedSymbol( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation), + module_handle)); + VLOG(3) << "Resolved global " + << llvm_ir::ConstantBufferAllocationToGlobalName(allocation) + << " to " << global.opaque(); + InsertOrDie(&globals, i, global); + + const Literal& literal = + llvm_ir::LiteralForConstantAllocation(allocation); + CHECK(ShapeUtil::IsArray(literal.shape())); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D( + literal.untyped_data(), allocation.size(), &global)); + } + } + } + + module_handles_.emplace(executor, + se::ScopedModuleHandle(executor, module_handle)); + return &module_globals_.emplace(executor, std::move(globals)).first->second; +} + StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -192,6 +242,10 @@ StatusOr GpuExecutable::ExecuteOnStream( } BufferAllocations::Builder buffer_allocations_builder; + se::StreamExecutor* executor = run_options->stream()->parent(); + + TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor)); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); @@ -213,8 +267,12 @@ StatusOr GpuExecutable::ExecuteOnStream( buffer_allocations_builder.RegisterBuffer(i, buffer); } + + if (allocation.is_constant()) { + buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i)); + } } - se::StreamExecutor* executor = run_options->stream()->parent(); + TF_ASSIGN_OR_RETURN( auto buffer_allocations, buffer_allocations_builder.Build( @@ -235,7 +293,7 @@ StatusOr GpuExecutable::ExecuteOnStream( // the respective location in ShapedBuffer. std::set buffers_in_result; TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + [&buffer_allocations, &buffers_in_result, this]( const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); // The points-to set is unambiguous so the set should be a diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 80ec38c3ac114fe4ad9d56784330c1144d913db1..c7ce6d0acbbbe594040271c0d45c71c016e36514 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -66,7 +68,7 @@ class GpuExecutable : public Executable { } // Returns the compiled PTX for the computation. - tensorflow::StringPiece ptx() const { return ptx_; } + const string& ptx() const { return ptx_; } // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be // empty, in which case compilation is left up to the GPU driver. @@ -98,6 +100,15 @@ class GpuExecutable : public Executable { // computation. Uses points-to analysis from buffer assignment. const PointsToSet& GetRootPointsToSet() const; + using BufferAllocToDeviceMemoryMap = + tensorflow::gtl::FlatMap; + + // Loads the PTX or CUBIN for this executable into `executor` and resolves the + // globals corresponding to constant buffers. Returns a map mapping buffer + // allocation indices to GPU pointers. + StatusOr ResolveConstantGlobals( + stream_executor::StreamExecutor* executor); + // The LLVM IR, in string format, of the unoptimized module generated for this // GpuExecutable. We save a string instead of an llvm::Module* because leaving // llvm::Module* in a singleton can cause the heap checker to emit false @@ -126,6 +137,14 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; + // Cache of module handles and constant buffer allocation maps used by + // `ResolveConstantGlobals`. + tensorflow::mutex module_handle_mutex_; + std::map + module_handles_ GUARDED_BY(module_handle_mutex_); + std::map + module_globals_ GUARDED_BY(module_handle_mutex_); + TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 09ef62c87f8875a5803497e8eb628769f883202a..d033faee8d25ed81a1483f8314652ef999ab36c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -31,20 +31,13 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::FilterLayout; - -static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { - int major, minor; - CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, - &minor)); - return major >= 7; -} +using se::dnn::DataLayout; +using se::dnn::FilterLayout; // Returns (input, filter, output) layouts. static std::tuple HeuristicLayoutAssignment(const HloInstruction* instr, - stream_executor::StreamExecutor* stream_executor) { + se::StreamExecutor* stream_executor) { // DataLayout and FilterLayout uses weird enum names. Translations: // N <=> Batch or Output // C <=> Depth or Input @@ -52,31 +45,44 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // W <=> X // // Therefore kOutputInputYX and kBatchDepthYX mean NCHW. + // + // If you have trouble keeping these straight, consider that all that matters + // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)? + + constexpr auto kAllNCHW = + std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + constexpr auto kAllNHWC = + std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); - // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x - // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version - // changes, as well as the hardware updates. + // If we're not Volta or not fp16, the decision is easy: Use NCHW. if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && IsVoltaOrLater(*stream_executor))) { - return std::make_tuple(DataLayout::kBatchDepthYX, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); - // For BackwardInput that has stride, full NHWC layouts run significantly - // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + + // Empirically we've found with Volta and cudnn 7 that backward-input convs + // with stride are significantly faster with NCHW layouts. // - // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, - // NHWC). + // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), + // which on paper gives good performance. However, there are two observations: + // * a mixed layout combination is more cuDNN-bug prone, based on empirical + // envidence. + // * we've also observed that for mixed layouts, cuDNN transposes data back + // and forth from a different layout combination. If we end up with + // transposes anyway, we prefer to have them in XLA, as they can be fused. + // TODO(timshen): Figure out the exact condition. This may be achieved by + // auto-tuning layouts offline. if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && window_util::HasStride(instr->window())) { - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputYXInput, - DataLayout::kBatchYXDepth); + + // For other Volta f16 convolutions, use NHWC. + return kAllNHWC; } // Adds layout constraints on the cudnn custom-call instruction. The layout @@ -170,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); } + + // For batched dot we require the default layout. + // TODO(b/112111608): This is overly conservative, the only real restriction + // is that batch dimensions must be major. + if (instruction->opcode() == HloOpcode::kDot && + ImplementedAsGemm(*instruction) && + instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) { + // Verify that the batch dims come before the row and col dims. + const DotDimensionNumbers& dim_nums = + instruction->dot_dimension_numbers(); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(instruction->shape())); + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + } + + // Set both inputs and the output to default layout. + Shape op0_shape = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&op0_shape); + Shape op1_shape = instruction->operand(1)->shape(); + LayoutUtil::SetToDefaultLayout(&op1_shape); + Shape output_shape = instruction->shape(); + LayoutUtil::SetToDefaultLayout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op1_shape, instruction, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); + } } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 95f78ae29326caad2f0785e2ba285a996e685899..286547ebae2f1a4b8d783a06d13b4dd96052b952 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -31,6 +33,8 @@ namespace xla { namespace gpu { namespace { +namespace op = xla::testing::opcode_matchers; + using LayoutAssignmentTest = HloTestBase; TEST_F(LayoutAssignmentTest, Elementwise) { @@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } } +TEST_F(LayoutAssignmentTest, DotLayout) { + const char* hlo_text = R"( + HloModule DotLayout + ENTRY dot { + p0 = f32[8,8,256,64]{3,1,2,0} parameter(0) + p1 = f32[8,8,256,64]{3,1,2,0} parameter(1) + ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1), + lhs_batch_dims={0,1}, lhs_contracting_dims={3}, + rhs_batch_dims={0,1}, rhs_contracting_dims={3} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + GpuLayoutAssignment layout_assignment(&computation_layout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = + ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 79b3f1efecdf06bfa93b17a1799f3009d517f3b5..a2f53f844613da9fe8166489dc9959e8d30c6332 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -117,38 +117,37 @@ StatusOr GpuTransferManager::TransferBufferToInfeedInternal( return std::move(buffer); } -static std::unique_ptr ShapeTreeToLiteral( +static void ShapeTreeToLiteral( ShapeTree>* shape_tree) { // This is a struct instead of a lambda for std::function-free recursion. struct Helper { - static std::unique_ptr helper( + static void helper( ShapeTree>* shape_tree, ShapeIndex* index) { const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); if (ShapeUtil::IsArray(shape)) { - return (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + return; } CHECK(ShapeUtil::IsTuple(shape)) << ShapeUtil::HumanStringWithLayout(shape); const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); index->push_back(0); - std::vector> tuple_operands; for (int64 i = 0; i < tuple_element_count; ++i) { index->back() = i; - tuple_operands.push_back(helper(shape_tree, index)); + helper(shape_tree, index); } index->pop_back(); - return LiteralUtil::MakeTupleOwned(std::move(tuple_operands)); } }; ShapeIndex index; - return Helper::helper(shape_tree, &index); + Helper::helper(shape_tree, &index); } Status GpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* /*executor*/, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { ShapeTree> outfeed_buffers( &literal_shape); @@ -162,6 +161,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( return; } *buffer = MakeUnique(GetByteSizeRequirement(shape)); + (*buffer)->set_destination( + MakeUnique(literal, index)); }); // Give the tree of buffers to the outfeed mananger. The device will fill it @@ -169,8 +170,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager(); outfeed_manager->EnqueueDestination(&outfeed_buffers); - // Now turn the tree of buffers back into a literal. - *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers)); + // Now wait for the tree of buffers are written. + ShapeTreeToLiteral(&outfeed_buffers); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index dceeb9e2eb01a7dd5e978d819ed1db56d828f353..7929042869763dfeab2fe8f87093b7ea758337d0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -42,7 +42,7 @@ class GpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 19420e590d05892417da4d5e62fdcde5eba9d9f1..17226769302eef0dd01550b0bc5404e889ad78f8 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/ptr_util.h" @@ -37,10 +37,9 @@ void InitAndStartTimer(std::stack>* timers, stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); } -uint64 GetCyclesTaken( - std::stack>* timers, - const std::vector::SmartPtr>& sub_streams, - se::Stream* stream, double clock_rate_ghz) { +uint64 GetCyclesTaken(std::stack>* timers, + const std::vector& sub_streams, + se::Stream* stream, double clock_rate_ghz) { CHECK_GT(timers->size(), 0); stream->ThenWaitFor(&sub_streams); stream->ThenStopTimer(timers->top().get()); @@ -53,7 +52,7 @@ uint64 GetCyclesTaken( HloExecutionProfiler::HloExecutionProfiler( bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, + const std::vector& sub_streams, const HloComputation* computation) : do_profile_(do_profile), profile_(profile), diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h index 6654850bef3efa46028defbba81e3537fafbf143..80cde75f2bbb555f514fffea58ad92edf92fd0d1 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -38,10 +38,10 @@ class ScopedInstructionProfiler; class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler( - bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, - const HloComputation* computation); + explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, + se::Stream* stream, + const std::vector& sub_streams, + const HloComputation* computation); // If profiling is enabled, sets the total cycle count on the profile from the // execution timer. @@ -72,7 +72,7 @@ class HloExecutionProfiler { double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; - const std::vector::SmartPtr>& sub_streams_; + const std::vector& sub_streams_; const HloComputation* computation_; std::stack> timers_; // Contains the HLO instructions for which we are currently measuring the diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 1b6315ec0305712d1367a9380f0de3eed91e2ee1..8c11cd05419289d82b033c936bb60884f45cb636 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -18,8 +18,10 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -110,6 +112,12 @@ void HloToIrBindings::EmitBasePointersForHlos( llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), index); + } else if (slice.allocation()->is_constant()) { + llvm::Value* global_for_constant = + module_->getGlobalVariable(llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); @@ -135,6 +143,14 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); } +// Returns true if `value` has a name that should not be changed. +static bool HasMeaningfulName(llvm::Value* value) { + if (auto* global = llvm::dyn_cast(value)) { + return global->getLinkage() != llvm::GlobalValue::PrivateLinkage; + } + return false; +} + llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, ShapeIndexView shape_index, llvm::Value* ir_value) { @@ -149,8 +165,13 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, } else { typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); } - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); - typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + if (!HasMeaningfulName(ir_value)) { + ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + } + if (!HasMeaningfulName(typed_ir_value)) { + typed_ir_value->setName( + llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + } return typed_ir_value; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 8abae43a5ad908a35256d32d3ae41acd0aa89337..0f2c83aeb2633a007559d8caac78ea2d233539ed 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -73,6 +73,67 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } } +// This function limits the maximum number of operands to a fusion. +// +// There's a cap on how many parameters we can pass to a CUDA kernel, but +// exactly what that limit is is hazy, as it depends on (among other things) how +// much GPU constant memory is in use for other purposes. +// +// Moreover, we don't even know at the point that we're running fusion how many +// arguments the CUDA kernel for a fusion node will have: It depends on buffer +// assignment, where we will decide which of the fusion's operands live in XLA's +// big temp buffer versus in other allocations. +// +// As a heuristic, we simply cap the number of fusion operands plus outputs at +// kMaxOperandsAndOutputsPerFusion. This puts an upper bound on the number of +// parameters to the kernel, working around the correctness problem. +// +// This limit is also often good for performance. In a fusion with many +// operands, each GPU thread likely has to do a lot of work, and so possibly +// uses a lot of registers, thus limiting occupancy. +/*static*/ bool GpuInstructionFusion::FusionWouldBeTooLarge( + const HloInstruction* a, const HloInstruction* b) { + // Compute the number of outputs of the (possibly multi-output) fusion node + // we're considering creating. + // + // This isn't precise; we may be off by one if + // - We're creating a multi-output fusion out of two non-MOFs. Creating a + // MOF adds a new buffer, namely, the tuple buffer. + // - We're merging two MOFs. In this case, we should count the tuple buffer + // only once. + // - WLOG there's an edge from `a` to `b` and `b` is the only consumer of + // `a`. In this case the result of `a` is not part of the output of the + // fusion. + // + // But because this is a heuristic and our limit + // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a + // big difference), we ignore this small inaccuracy in favor of simplicity. + int64 num_output_buffers = ShapeUtil::SubshapeCount(a->shape()) + + ShapeUtil::SubshapeCount(b->shape()); + + // The new fusion will have no more operands and outputs than + // producer_operands + consumer_operands - 1 + num_output_buffers + // (minus one because we may be fusing a producer->consumer edge between `a` + // and `b`). + // + // This fact may be enough to let us avoid having to compute the true total + // number of operands, which can be expensive. + if (a->operand_count() + b->operand_count() - 1 + num_output_buffers <= + kMaxOperandsAndOutputsPerFusion) { + return false; + } + + // Compute the precise number of operands to the new fusion. + tensorflow::gtl::FlatSet operands( + a->operands().begin(), a->operands().end()); + operands.insert(b->operands().begin(), b->operands().end()); + // If there's an edge between `a` and `b`, don't count it: We're fusing that + // producer -> consumer relationship. + operands.erase(a); + operands.erase(b); + return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); @@ -141,6 +202,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, IsIEEEFloatingPointScalarConstant(producer->operand(0)) && fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; } + return false; } // Other output fusions are not currently supported on GPUs. @@ -188,48 +250,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // Limit the maximum number of operands to a fusion. - // - // There's a limit to how many parameters we can pass to a CUDA kernel, but - // exactly what that limit is is hazy, as it depends on (among other things) - // how much GPU constant memory is in use for other purposes. - // - // Moreover, we don't even know at this point how many arguments the CUDA - // kernel for this fusion node will have: It depends on buffer assignment, - // where we will decide which of the fusion's operands live in XLA's big temp - // buffer versus in other allocations. - // - // As a heuristic, we simply cap the number of fusion operands at - // kMaxOperandsPerFusion. This puts an upper bound on the number of - // parameters to the kernel, working around the correctness problem. - // - // This limit is also often good for performance. In a fusion with many - // operands, each GPU thread likely has to do a lot of work, and so possibly - // uses a lot of registers, thus limiting occupancy. - // - // We put this check last because it's expensive to compute. - - // The new fusion will have no more operands than - // producer_operands + consumer_operands - 1 - // (minus one because we're fusing the producer->consumer edge). This fact - // may be enough to let us avoid having to compute the true total number of - // operands, taking into account the fact that producer and consumer may share - // operands. - if (producer->operand_count() + consumer->operand_count() - 1 > - kMaxOperandsPerFusion) { - tensorflow::gtl::FlatSet producer_operands( - producer->operands().begin(), producer->operands().end()); - int64 new_num_operands = - producer->operand_count() + - c_count_if(consumer->operands(), [&](const HloInstruction* operand) { - return operand != producer && !producer_operands.count(operand); - }); - if (new_num_operands > kMaxOperandsPerFusion) { - return false; - } - } - - return true; + // We put this check last because it's potentially expensive. + return !FusionWouldBeTooLarge(consumer, producer); } bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index 5ee1c004b6142140a1848b3f85a8c2d3e0f41a42..c91f6343a69268ca687004dbe0ffbb863271a95c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,19 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + // Maximum number of operands plus outputs allowed on a single fusion node. + // Exposed publicly mainly for tests. + static constexpr int64 kMaxOperandsAndOutputsPerFusion = 64; + + // Determines whether the combination of `a` and `b` into a (possibly + // multi-output) fusion would be "too large" -- i.e., have more operands and + // outputs than is allowed. + // + // `ShouldFuse` and `ShouldFuseIntoMultiOutput` call this; it's public so that + // other fusion passes (e.g. GPU multi-output fusion) can also call this. + static bool FusionWouldBeTooLarge(const HloInstruction* a, + const HloInstruction* b); + static bool IsExpensive(const HloInstruction& instruction); bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; @@ -36,10 +49,6 @@ class GpuInstructionFusion : public InstructionFusion { HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; - - // Maximum number of operands allowed on a single fusion node. Exposed - // publicly mainly for tests. - static constexpr int64 kMaxOperandsPerFusion = 64; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 229eb23f124829e490f1fe59fb28fe97dad9211a..8d0522bd8fd6659e64d18c52807df8dc7fc2f3b8 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -609,7 +609,7 @@ TEST_F(InstructionFusionTest, FuseScalarConstant) { // Check that we limit the number of operands to fusions we create. TEST_F(InstructionFusionTest, AvoidsLargeFusion) { constexpr int64 kNumParams = 200; - ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsPerFusion); + ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion); // Compute p0 + p1 + ... + pN. HloComputation::Builder b(TestName()); @@ -631,7 +631,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { SCOPED_TRACE(module->ToString()); for (const HloInstruction* instr : computation->instructions()) { EXPECT_LE(instr->operand_count(), - GpuInstructionFusion::kMaxOperandsPerFusion) + GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion) << instr->ToString(); } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 2799baab41ce15e959be16102b534aac0c8f345a..c349063c71f000435a05306101ad724505f2d197 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -38,24 +38,27 @@ namespace gpu { namespace { // Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) { + return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 && + !LayoutUtil::IsPadded(shape); } // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { + const Shape& output_shape, + int64 batch_dimensions_size) { // The inputs and the output must // 1) be matrices with no padding and a non-zero number of elements, // 2) have an allowed element type. PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64); - return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && - IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape) && + output_primitive_type == F64 || output_primitive_type == C64); + return type_is_allowed && + IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(output_shape, batch_dimensions_size) && !ShapeUtil::IsZeroElementArray(lhs_shape) && !ShapeUtil::IsZeroElementArray(rhs_shape); } @@ -64,14 +67,15 @@ bool DotImplementedAsGemm(const HloInstruction& dot) { CHECK_EQ(dot.opcode(), HloOpcode::kDot); const Shape& lhs_shape = dot.operand(0)->shape(); const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(), + dim_numbers.lhs_batch_dimensions_size())) { // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; @@ -81,11 +85,6 @@ bool DotImplementedAsGemm(const HloInstruction& dot) { } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { - // We can only do this if the HLO is unnested. - if (hlo.parent() != hlo.GetModule()->entry_computation()) { - return false; - } - // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { return DotImplementedAsGemm(hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 9bb4c42b15016a763889a903593cb987afda8da9..5d23a3d01842c7b4ff405171cd49c96a19f7e5b0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -31,6 +31,12 @@ namespace gpu { constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. +// +// Precondition: `hlo` is in an "unnested context", meaning, it lives within the +// entry computation, within the either of a while loop's subcomputations, +// within any of a conditional's subcomputations, etc., but *does not* live +// within a reduce subcomputation, a map subcomputation, a fusion +// subcomputation, etc. It's OK if `hlo` *is* a fusion. bool ImplementedAsGemm(const HloInstruction& hlo); // A call to cuDNN for batch normalization is represented as CustomCall HLO with diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index f95541cba43dc6586ea77cb12404e24d0e08db01..541cacf6970453033c09a153a2dd320d4ebf3d4a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -81,19 +81,6 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { } Status IrEmitter::HandleConstant(HloInstruction* constant) { - const Literal& literal = constant->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl - << " emitted_value: " << llvm_ir::DumpToString(*global_for_const) - << std::endl - << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); return Status::OK(); } @@ -138,6 +125,10 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) { return Unimplemented("Recv-done is not implemented on GPU"); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on GPUs."); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { @@ -463,6 +454,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); // TODO(b/110211620): Convert to use i32 index_type when it is possible. llvm::Type* index_type = b_.getInt64Ty(); @@ -498,9 +492,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const int64 lhs_reduction_dimension = ShapeUtil::GetDimensionNumber(lhs_shape, -1); const int64 rhs_reduction_dimension = - ShapeUtil::Rank(rhs_shape) >= 2 + ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size() ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) - : 0; + : dnums.lhs_batch_dimensions_size(); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dnums.lhs_batch_dimensions()) { + CHECK_NE(lhs_reduction_dimension, batch_dim); + CHECK_NE(rhs_reduction_dimension, batch_dim); + } // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -515,6 +515,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + // We don't have to iterate over the batch dimensions in both arrays, simplify + // the loop nest of the rhs. + for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { + DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i)); + rhs_index[i] = lhs_index[i]; + } + // Create the reduction loop which does the sum of products reduction. std::unique_ptr reduction_loop = loop_nest.AddLoop( /*start_index=*/0, @@ -577,7 +584,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { target_index.push_back(lhs_index[dimension]); } } - for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) { + // Skip over the batch dimensions to not have them in the index twice. + for (size_t dimension = dnums.lhs_batch_dimensions_size(); + dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { target_index.push_back(rhs_index[dimension]); } @@ -716,23 +725,6 @@ Status IrEmitter::HandleOutfeed(HloInstruction*) { return Unimplemented("Outfeed is not supported on GPU."); } -Status IrEmitter::HandleRng(HloInstruction* random) { - ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; - for (const HloInstruction* operand : random->operands()) { - operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand, *random).EmitReadArrayElement(index, &b_); - }; - } - // Emits a single-threaded loop because the loop body generated by the element - // generator for Rng can't be parallelized (b/32333178). - return llvm_ir::LoopEmitter( - GpuElementalIrEmitter(hlo_module_config_, module_, &b_, - GetNestedComputer()) - .MakeElementGenerator(random, operand_to_generator), - GetIrArray(*random, *random), &b_) - .EmitLoop(IrName(random)); -} - Status IrEmitter::HandleBatchNormInference(HloInstruction*) { return Unimplemented( "The GPU backend does not implement BatchNormInference directly. It " diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index e89967a378d97d00e0050b9b731795b3cdf51857..561c6838798aa92ce2c96b3c45d5ba42fe6edef3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -86,12 +86,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleRng(HloInstruction* random) override; Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1f31a7f36bbf79aec67a99c0781d139bfa5257e3..a093ffc7c1293d5dc9e44de97896add12a3ae510 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -55,10 +56,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" @@ -66,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -75,6 +77,7 @@ limitations under the License. #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -168,40 +171,6 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { return DfsHloVisitor::Postprocess(hlo); } -namespace { -bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, - const HloInstruction& hlo) { - // `hlo` needs to satisfy the following conditions to be implemented as a - // host-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo`'s only operand is a kConstant instruction. - // 3. `hlo` and its operand have the same shape (thus the same layout too). - // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing - // pointers in a tuple). - return hlo.opcode() == HloOpcode::kCopy && - hlo.operand(0)->opcode() == HloOpcode::kConstant && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok(); -} - -bool ImplementedAsDeviceToDeviceMemcpy( - const BufferAssignment& buffer_assignment, const HloInstruction& hlo) { - // `hlo` needs to satisfy three conditions to be implemented as a - // device-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo` and its operand have the same shape (thus the same layout too). - // 3. `hlo` and its operand have a statically-known buffer assignment - // (constants do not, for instance), which means the source buffer also - // resides on the device. - return hlo.opcode() == HloOpcode::kCopy && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() && - buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok(); -} -} // namespace - llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, tensorflow::gtl::ArraySlice args) { @@ -230,11 +199,20 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ++arg_it; kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); + + const int64 alignment = [&] { + if (alloc->is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (alloc->is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + kernel->addParamAttr( - arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, - alloc->is_entry_computation_parameter() - ? kEntryParameterAlignBytes - : kXlaAllocatedBufferAlignBytes)); + arg_no, + llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment)); if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); @@ -367,11 +345,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { - const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 || - dnums.rhs_batch_dimensions_size() > 0) { - return Unimplemented("Dot with batch dimensions not implemented."); - } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); @@ -718,13 +691,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), - *copy)) { - thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); - return Status::OK(); - } - if (ImplementedAsDeviceToDeviceMemcpy( - ir_emitter_context_->buffer_assignment(), *copy)) { + CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape())); + const BufferAssignment& buffer_assignment = + ir_emitter_context_->buffer_assignment(); + if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), + copy->shape().layout()) && + buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } @@ -1762,6 +1734,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { .GetUniqueTopLevelSlice(tuple_element) .ok(); }); + // TODO(b/111689850): This logic isn't quite correct. + // // Tuples (especially tuples that are the final result of a computation) can // be so huge that if we were to emit a kernel that took each tuple element as // a parameter, we would exceed the max allowable number of parameters to a @@ -1769,9 +1743,9 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { // buffer, we collect their buffer addresses in a host array, and then copy // that array to the tuple's buffer. // - // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back to - // emitting kernels which have access to their buffer addresses in code. + // Some tuple elements might not have an unambiguous buffer (like the result + // of a select-tuple). In that case, we fall back to emitting kernels which + // have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; for (const HloInstruction* tuple_element : tuple->operands()) { @@ -1989,27 +1963,55 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; // Build ForThunk for conformant while loops, otherwise build WhileThunk. - auto result = CanTransformWhileToFor(xla_while); - if (result.ok()) { - auto tuple = result.ConsumeValueOrDie(); - // loop_trip_count = (limit - start + increment - 1) / increment - const int64 loop_trip_count = - (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) / - std::get<2>(tuple); - thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count)); + // TODO(b/112163966): Move trip count computation earlier in the pipeline. + if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { + thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count)); VLOG(3) << "Built ForThunk for while: " << xla_while->name(); } else { thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); - VLOG(3) << "Built WhileThunk for while: " << xla_while->name() - << " while-to-for transform status: " << result.status(); + VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); } -Status IrEmitterUnnested::HandleRng(HloInstruction* random) { - thunk_sequence_->push_back( - BuildKernelThunk(random, /*implements_whole_instruction=*/true)); - return IrEmitter::HandleRng(random); +Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { + // Build the kernel to generate the random numbers. + // + // Unroll the kernel so that the duplicated computation that calculates the + // 128 bit sample can be optimized away by LLVM. + thunk_sequence_->emplace_back( + BuildKernelThunk(rng, /*implements_whole_instruction=*/false, + ComputeMaxUnrollFactor(rng))); + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : rng->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*operand, *rng).EmitReadArrayElement(index, &b_); + }; + } + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + *rng, GpuElementalIrEmitter(hlo_module_config_, module_, &b_, + GetNestedComputer()) + .MakeElementGenerator(rng, operand_to_generator))); + std::unique_ptr rng_thunk = std::move(thunk_sequence_->back()); + thunk_sequence_->pop_back(); + + // Emit a kernel to increment the global state for Philox RNG algorithm. + thunk_sequence_->emplace_back( + BuildKernelThunk(rng, /*implements_whole_instruction=*/false)); + llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); + std::unique_ptr increment_seed_thunk = + std::move(thunk_sequence_->back()); + thunk_sequence_->pop_back(); + + // Build the SequentialThunk for the RNG hlo. + std::vector> thunks; + thunks.reserve(2); + thunks.push_back(std::move(rng_thunk)); + thunks.push_back(std::move(increment_seed_thunk)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), rng)); + + return Status::OK(); } Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { @@ -2020,28 +2022,34 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; + auto keys = sort->operand(0); auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); if (values != nullptr) { - // TODO(b/26783907): Also sort the values by their corresponding key. - return Unimplemented("Key/Value Sort is not implemented on GPU"); + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); - // First copy the operand to the output, so that we can sort in-place. - // TODO(b/26783907): Share buffer of output and operand when it is possible. - if (sort->operand(0)->IsConstant()) { - thunks.push_back(MakeUnique( - /*source_address=*/sort->operand(0)->literal().untyped_data(), - /*destination_buffer=*/GetAllocationSlice(*sort), - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); - } else { + if (keys_destination != GetAllocationSlice(*keys)) { thunks.push_back(MakeUnique( - /*source_address=*/GetAllocationSlice(*sort->operand(0)), - /*destination_buffer=*/GetAllocationSlice(*sort), - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + /*source_address=*/GetAllocationSlice(*keys), + /*destination_buffer=*/keys_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); + } + if (values != nullptr && values_destination != GetAllocationSlice(*values)) { + // TODO(b/26783907): Figure out why we never seem to share buffers for + // key/value sort. + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*values), + /*destination_buffer=*/values_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); } int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = sort->shape().dimensions(dimension_to_sort); + int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); auto index_type = b_.getInt64Ty(); @@ -2065,7 +2073,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { thunks.push_back( BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - sort->shape(), ir_emitter_context_->device_description()); + keys->shape(), ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); @@ -2077,8 +2085,11 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, GetIrArray(*sort, *sort), IrName(sort), xor_mask, - &b_, &launch_dimensions)); + dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), + values != nullptr ? tensorflow::gtl::make_optional( + GetIrArray(*sort, *sort, values_shape_index)) + : tensorflow::gtl::nullopt, + IrName(sort), xor_mask, &b_, &launch_dimensions)); } } @@ -2240,11 +2251,6 @@ GetHloBufferSlices(const HloInstruction* hlo, // Adds entries for all subshapes of instr to `slices`. auto add_slices_for = [&](const HloInstruction* instr) { - // GPU constants don't have buffers; don't bother looking for one. - if (instr->IsConstant()) { - return; - } - ShapeUtil::ForEachSubshape( instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { if (slices.count({instr, index})) { @@ -2306,21 +2312,25 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. - std::vector buffers(buffers_needed.begin(), - buffers_needed.end()); - std::sort(buffers.begin(), buffers.end(), + std::vector non_constant_buffers; + c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); + + std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); + llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. std::unordered_map kernel_args; { auto arg_it = kernel->arg_begin(); - auto buffers_it = buffers.begin(); + auto buffers_it = non_constant_buffers.begin(); for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { kernel_args[*buffers_it] = arg_it; } @@ -2338,8 +2348,16 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( << " is found in slice " << slice.ToString() << " at GTE index " << gte_index.ToString(); - llvm::Value* loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + llvm::Value* loc; + if (slice.allocation()->is_constant()) { + loc = ir_emitter_context_->llvm_module()->getGlobalVariable( + llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + CHECK_NE(loc, nullptr); + } else { + loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); + } // If gte_index is nonempty, we have to dereference `loc` to get to the // value we're ultimately interested in. @@ -2362,9 +2380,9 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), - implements_whole_instruction ? inst : nullptr, - unroll_factor); + return MakeUnique( + non_constant_buffers, llvm_ir::AsString(kernel->getName()), + implements_whole_instruction ? inst : nullptr, unroll_factor); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2601,7 +2619,17 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // If the init_value was fused into this reduce we have to generate it first. if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + + const Literal& literal = init_value_operand->literal(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + *module_, initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, + /*Name=*/""); + global_for_const->setAlignment(kConstantBufferAlignBytes); + bindings_.BindHloToIrValue(*init_value_operand, global_for_const); } TF_RETURN_IF_ERROR(ParallelLoopEmitter( [=](const IrArray::Index& index) { @@ -2719,13 +2747,13 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( HloComputation* condition = hlo->while_condition(); IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, ir_emitter_context_); - TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition)); + TF_CHECK_OK(condition->Accept(&ir_emitter_condition)); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, ir_emitter_context_); - TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); return MakeUnique( GetAllocationSlice(*condition->root_instruction()), // cond result @@ -2743,7 +2771,7 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, ir_emitter_context_); - TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); return MakeUnique(loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2759,12 +2787,12 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( HloComputation* true_computation = hlo->true_computation(); IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation, ir_emitter_context_); - TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true)); + TF_CHECK_OK(true_computation->Accept(&ir_emitter_true)); HloComputation* false_computation = hlo->false_computation(); IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation, ir_emitter_context_); - TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false)); + TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); return MakeUnique( GetAllocationSlice(*hlo->operand(0)), @@ -3333,5 +3361,47 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } +Status IrEmitterUnnested::EmitConstantGlobals() { + for (const BufferAllocation& allocation : + ir_emitter_context_->buffer_assignment().Allocations()) { + if (!allocation.is_constant()) { + continue; + } + + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); + llvm::ArrayType* global_type = + llvm::ArrayType::get(b_.getInt8Ty(), allocation.size()); + llvm::Constant* initializer = + should_emit_initializer + ? llvm_ir::ConvertLiteralToIrConstant(literal, module_) + : llvm::ConstantAggregateZero::get(global_type); + if (should_emit_initializer) { + VLOG(3) << "Emitted initializer for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + } + + // These globals will be looked up by name by GpuExecutable so we need to + // give them an external linkage. Not all of their uses are visible in the + // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely + // preserves their names (like available_externally), we also need to ensure + // that they stick around even if they're "unused". + // + // We may have to be more more clever here in the future if we notice that + // we're keeping around too many globals because of their linkage. + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + global_type, /*isConstant=*/should_emit_initializer, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/initializer, + llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation))); + global_for_const->setAlignment(kConstantBufferAlignBytes); + ir_emitter_context_->llvm_module()->getGlobalList().push_back( + global_for_const); + } + + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 616d8a2206e5a9666947008879c48f99a022e899..525441990795e160ba0e8facb910d5cc9796c4bb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -92,6 +92,9 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, KernelThunk* thunk); + // Emits LLVM global variables corresponding to constant instructions. + Status EmitConstantGlobals(); + private: // Builds the appropriate thunk for the instruction hlo and returns the owning // pointer to it. The caller needs to make sure `inst` outlives the lifetime diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 6c1c20fc0464927054deace8980620c3a9c6f09b..cf44458a2ed6c069c1469bb975c62565264451c1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -114,21 +114,20 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Gets the GPU name as it's known to LLVM for a given compute capability. If // we see an unrecognized compute capability, we return "sm_30". static string GetSmName(std::pair compute_capability) { - static auto* m = new std::map, int>( - {{{2, 0}, 20}, - {{2, 1}, 21}, - {{3, 0}, 30}, - {{3, 2}, 32}, - {{3, 5}, 35}, - {{3, 7}, 37}, - {{5, 0}, 50}, - {{5, 2}, 52}, - {{5, 3}, 53}, - {{6, 0}, 60}, - {{6, 1}, 61}, - {{6, 2}, 62}, - // TODO: Change this to 70 once LLVM NVPTX supports it - {{7, 0}, 60}}); + static auto* m = new std::map, int>({ + {{3, 0}, 30}, + {{3, 2}, 32}, + {{3, 5}, 35}, + {{3, 7}, 37}, + {{5, 0}, 50}, + {{5, 2}, 52}, + {{5, 3}, 53}, + {{6, 0}, 60}, + {{6, 1}, 61}, + {{6, 2}, 62}, + {{7, 0}, 70}, + {{7, 2}, 72}, + }); int sm_version = 30; auto it = m->find(compute_capability); if (it != m->end()) { @@ -329,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, if (linker.linkInModule( std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded, [](Module& M, const StringSet<>& GVS) { - internalizeModule(M, [&M, &GVS](const GlobalValue& GV) { + internalizeModule(M, [&GVS](const GlobalValue& GV) { return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index f95fbb01f9ddf1284c44dc0f576d3e79139c9df1..c62bae0628f7b2fbfe822104fbe5f3528e0e09c3 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -112,17 +113,25 @@ bool IsInputFusibleReduction(HloInstruction* instr) { // of input parameters differ. In such situtations it is beneficial not to fuse. // We consider input params with maximum rank only. Params with smaller ranks // will be broadcasted and have not been observed to cause data locality issues. -// TODO(b/110927656): Improve reduce emitters to remove this limitation. +// TODO(b/111977086): Improve reduce emitters to remove this limitation. bool ReduceFriendlyInputLayouts(HloInstruction* instr) { + std::vector params; + if (instr->opcode() == HloOpcode::kFusion) { + params = instr->fused_parameters(); + } else { + for (HloInstruction* operand : instr->operands()) { + params.push_back(operand); + } + } int64 max_rank = 0; const Layout* max_rank_layout; - for (HloInstruction* param : instr->fused_parameters()) { + for (HloInstruction* param : params) { if (ShapeUtil::Rank(param->shape()) > max_rank) { max_rank = ShapeUtil::Rank(param->shape()); max_rank_layout = ¶m->shape().layout(); } } - return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) { + return c_all_of(params, [&](HloInstruction* param) { return (ShapeUtil::Rank(param->shape()) < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); @@ -163,16 +172,22 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) { return false; } + // If we're fusing fusions only do it if the fusion kind matches. Loop fusions // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions // together too (the result being an input fusion) if we find cases where this // improves things. CHECK(instr1->opcode() == HloOpcode::kFusion); - if (instr2->opcode() == HloOpcode::kFusion) { - return instr1->fusion_kind() == instr2->fusion_kind(); + if ((instr2->opcode() == HloOpcode::kFusion && + instr1->fusion_kind() != instr2->fusion_kind()) || + (instr2->opcode() != HloOpcode::kFusion && + instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { + return false; } - return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop; + + // Do this check last, as it may be expensive. + return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { @@ -214,7 +229,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { const bool is_loop_fusion = producer->opcode() == HloOpcode::kFusion && producer->fusion_kind() == HloInstruction::FusionKind::kLoop; - if (!is_loop_fusion) { + if (!producer->IsElementwise() && !is_loop_fusion) { VLOG(3) << producer->name() << " is not a loop fusion."; continue; } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 451e49f23a3404ab96524fcd3f2d721effda8f31..14f157a5e518a0ec82c664c123629d04bd385bbf 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -255,6 +256,26 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Exp())); +} + TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_add { @@ -419,5 +440,58 @@ TEST_F(MultiOutputFusionTest, ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } +// Check that we limit the number of operands to fusions we create. +TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { + constexpr int64 kNumParams = 200; + ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion); + + // Compute + // p0 * p1, + // p0 * p1 + p1 * p2 + // p0 * p1 + p1 * p2 + p2 * p3 + // ... + // where each of the (pi * pj)'s is represented as a fusion node so that + // multi-output fusion will pay attention to it. + auto module = CreateNewModule(); + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); + + std::vector params; + for (int64 i = 0; i < kNumParams; ++i) { + params.push_back( + b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"))); + } + + // Creates a fusion node that calculates x*y. + auto make_fusion = [&](HloInstruction* x, HloInstruction* y) { + HloComputation::Builder sub_builder("subcomp"); + auto* p0 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "p")); + auto* p1 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "p")); + sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); + HloComputation* subcomp = + module->AddEmbeddedComputation(sub_builder.Build()); + return HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp); + }; + + auto* sum = b.AddInstruction(make_fusion(params[0], params[1])); + for (int64 i = 2; i < kNumParams; ++i) { + sum = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, sum, + b.AddInstruction(make_fusion(params[i - 1], params[i])))); + } + auto computation = module->AddEntryComputation(b.Build()); + EXPECT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + for (const HloInstruction* instr : computation->instructions()) { + EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()), + GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion) + << instr->ToString(); + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 2eefadebcd1098b294c79d6e400857b1c2760824..76c9b6ab33befa98f03821fac84071bd978ae24d 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" @@ -52,9 +51,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -74,7 +75,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" -#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -146,7 +146,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // support BF16 operations without directly implementing a BF16 lowering for // most ops. pipeline.AddPass(BF16, F32); - pipeline.AddPass(); { auto& pass = @@ -199,6 +198,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + if (IsVoltaOrLater(*stream_exec)) { + pipeline.AddPass(); + // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element + // pairs that TupleSimplifier fixes. + pipeline.AddPass(); + } TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -275,14 +280,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } } - { - // Do an aggressive LICM pass over while loops. In particular, this hoists - // constants that were sunk by WhileLoopConstantSinking. Leaving them in - // the while loop may result in unnecessary copies. - HloPassPipeline pipeline("while-loop-licm"); - pipeline.AddPass(true); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } return Status::OK(); } @@ -540,11 +537,13 @@ StatusOr> NVPTXCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, - BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/[](LogicalBuffer::Color) { - return kXlaAllocatedBufferAlignBytes; - })); + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); @@ -565,6 +564,9 @@ StatusOr> NVPTXCompiler::RunBackend( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission"); TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h index a752eb70119b00e8cca7ddce26da7730ef5db8cb..160ba4b691f818ff01b41b8603c11853ea12c253 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h @@ -36,22 +36,19 @@ class OutfeedBuffer { OutfeedBuffer(int64 length) : length_(length) {} // Waits for the device transfer to be finished. - std::unique_ptr WaitUntilAvailable() { - done_.WaitForNotification(); - return std::move(destination_); - } + void WaitUntilAvailable() { done_.WaitForNotification(); } int64 length() const { return length_; } - void set_destination(std::unique_ptr destination) { + void set_destination(std::unique_ptr destination) { destination_ = std::move(destination); } - Literal* destination() { return destination_.get(); } + MutableBorrowingLiteral* destination() { return destination_.get(); } // Callback to signal that this buffer is consumed. void Done() { done_.Notify(); } private: - std::unique_ptr destination_; + std::unique_ptr destination_; const int64 length_; tensorflow::Notification done_; }; diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 7986e63f43ee508370f94fdb9057b91bfe4add18..b99d998c4d7df514c024b1f8d643d08c72059d0e 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -50,10 +50,6 @@ Status OutfeedThunk::ExecuteOnStream( if (!*buffer) { // Tuple pointers. return Status::OK(); } - // Allocate storage for the literal data. - const Shape& shape = - ShapeUtil::GetSubshape(outfeed_buffers->shape(), index); - (*buffer)->set_destination(Literal::CreateFromShape(shape)); BufferAllocation::Slice slice = outfeed_slices_.element(index); se::DeviceMemoryBase data_address; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc new file mode 100644 index 0000000000000000000000000000000000000000..79f7d31816baf0b95b967771b956a9c06ac81e91 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -0,0 +1,233 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace gpu { + +using tensorflow::gtl::ArraySlice; + +// We want the input/output feature counts of an f16 conv to be factors of 8, +// because without this cudnn can't use tensor cores on the conv. +static constexpr int64 kDesiredNumFeaturesFactor = 8; + +// We won't pad a conv if doing so increases the total number of bytes in the +// lhs, rhs, or result by more than this amount. +// +// TODO(jlebar): This number was tuned experimentally. It represents a +// compromise on our current benchmarks; it speeds some up significantly, and +// doesn't slow any down. But we can observe by changing this value that +// there's additional room for speedups. Achieving those speedups without also +// slowing other things down will likely require a more sophisticated heuristic, +// possibly some form of auto-tuning. +static constexpr double kMaxBytesTouchedIncrease = 1.2; + +// Pads the given dimensions in the given shape up to a multiple of +// kDesiredNumFeaturesFactor. +static Shape PadShape(Shape s, ArraySlice dims) { + for (int64 dim : dims) { + int64 dim_to_pad_size = s.dimensions(dim); + int64 new_dim_to_pad_size = + RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + s.set_dimensions(dim, new_dim_to_pad_size); + } + return s; +} + +// Creates and returns an HLO that zero-pads one or more dimensions in the given +// instruction so that its shape is equal to the given shape. +// +// Padding is added to the end of each relevant dimension. +// +// If the instruction already has the given shape, simply returns it without an +// intervening pad. +static HloInstruction* PadInstruction(HloInstruction* instr, + const Shape& new_shape) { + HloComputation* comp = instr->parent(); + + const Shape& shape = instr->shape(); + auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + + PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + + bool added_padding = false; + for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + if (shape.dimensions(dim) == new_shape.dimensions(dim)) { + continue; + } + CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim)); + pad_config.mutable_dimensions(dim)->set_edge_padding_high( + new_shape.dimensions(dim) - shape.dimensions(dim)); + added_padding = true; + } + + if (!added_padding) { + return instr; + } + return comp->AddInstruction( + HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); +} + +// Pads the input/output feature dimensions of the given cudnn convolution +// custom-call to be multiples of kDesiredNumFeaturesFactor. +static StatusOr PadFeaturesDims(HloInstruction* conv) { + CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) + << "conv must use 0 scratch bytes, i.e. this pass must be run " + "before CudnnConvolutionAlgorithmPicker."; + + const auto& target = conv->custom_call_target(); + const auto& dnums = conv->convolution_dimension_numbers(); + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + const Shape& result_shape = conv->shape().tuple_shapes(0); + + Shape new_lhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardFilterCallTarget) { + // LHS is "input". + return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); + // LHS is "output". + return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); + }(); + + Shape new_rhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget) { + // RHS is "filter". + return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // RHS is "output". + return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); + }(); + + if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && + ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + return false; + } + + Shape new_result_shape = [&] { + if (target == kCudnnConvForwardCallTarget) { + // Result is "output". + return PadShape(result_shape, {dnums.output_feature_dimension()}); + } + if (target == kCudnnConvBackwardInputCallTarget) { + // Result is "input". + return PadShape(result_shape, {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // Result is "filter". + return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + }(); + + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. + auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return true; + } + VLOG(3) << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + return false; + }; + if (!check_size_increase(lhs->shape(), new_lhs_shape) || + !check_size_increase(rhs->shape(), new_rhs_shape) || + !check_size_increase(result_shape, new_result_shape)) { + return false; + } + + // OK, let's do the transformation! + + auto* new_lhs = PadInstruction(lhs, new_lhs_shape); + auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + CHECK(new_lhs != lhs || new_rhs != rhs) + << "We should have had to pad either LHS or RHS."; + + auto add = [&](std::unique_ptr new_instr) { + return conv->parent()->AddInstruction(std::move(new_instr)); + }; + + Shape new_conv_shape = ShapeUtil::MakeTupleShape( + {new_result_shape, ShapeUtil::MakeShape(U8, {0})}); + auto* new_conv = + add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs})); + + // Slice the new conv result if necessary, keeping in mind that new_conv has + // tuple shape (new_result_shape, u8[0]). + if (!ShapeUtil::Equal(result_shape, new_result_shape)) { + std::vector start_indices(result_shape.dimensions_size(), 0); + std::vector end_indices(result_shape.dimensions().begin(), + result_shape.dimensions().end()); + std::vector strides(result_shape.dimensions_size(), 1); + + auto* new_conv_result = add( + HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0)); + auto* empty_temp_buffer = + add(HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + auto* sliced_result = add(HloInstruction::CreateSlice( + result_shape, new_conv_result, start_indices, end_indices, strides)); + new_conv = + add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer})); + } + + VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " + << new_conv->ToString(); + TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv)); + return true; +} + +static std::vector GetRelevantConvs(HloComputation* comp) { + std::vector convs; + for (HloInstruction* instr : comp->instructions()) { + if (IsCustomCallToDnnConvolution(*instr) && + instr->operand(0)->shape().element_type() == F16) { + convs.push_back(instr); + } + } + return convs; +} + +StatusOr PadForTensorCores::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloInstruction* conv : GetRelevantConvs(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); + changed |= result; + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h new file mode 100644 index 0000000000000000000000000000000000000000..192359f026bfb2f1d5436713e4a30725fa0ad6ba --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Ensures that f16 cudnn convolutions have input/output channel dimensions that +// are multiples of 8, inserting pads/slices as necessary. +// +// This is useful primarily for Volta and newer GPUs, where tensor cores can +// only be used if the channel dims are multiples of 8. It's probably the +// opposite of useful on other GPUs, so you should check what GPU you're +// targeting before running this pass. +// +// TODO(jlebar): Also pad dots. +class PadForTensorCores : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "pad for tensor cores"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..99e7580b826fc5cd6d98a037a5eb064552952e18 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +using PadForTensorCoresTest = HloVerifiedTestBase; + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + filter = f16[2,2,41,40] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + + SCOPED_TRACE(module().ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,41] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); +} + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvForwardCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _)); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,40] parameter(0) + filter = f16[2,2,41,40] parameter(1) + result = (f16[10,20,30,41], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardInputCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + output = f16[10,20,30,40] parameter(1) + result = (f16[2,2,41,40], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Pad(op::Parameter(0), _), op::Parameter(1)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + output = f16[10,20,30,41] parameter(1) + result = (f16[2,2,40,41], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Parameter(0), op::Pad(op::Parameter(1), _)))), + _))); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index e4cfc6999f2da04dd7e7a34d854fdb3d75b8bfc6..0806dd51614f4d2da12f3fbbc9fb98df5273d5c8 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -33,13 +33,13 @@ int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { } void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo, - int stream_no) { - CHECK_GE(stream_no, 0); - if (stream_no >= stream_count_) { - stream_count_ = stream_no + 1; + int stream_num) { + CHECK_GE(stream_num, 0); + if (stream_num >= stream_count_) { + stream_count_ = stream_num + 1; } - InsertOrDie(&hlo_to_stream_number_, hlo, stream_no); - VLOG(2) << "Assign stream #" << stream_no << " to " << hlo->ToString(); + InsertOrDie(&hlo_to_stream_number_, hlo, stream_num); + VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString(); } namespace { @@ -51,6 +51,12 @@ bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b, return !reachability.IsConnected(&a, &b); } +constexpr int kInvalidStreamNum = -1; +// Returns true iff `stream_num` is an invalid stream number. +inline bool IsStreamNumValid(int stream_num) { + return stream_num != kInvalidStreamNum; +} + // Returns which existing stream to assign to `hlo`, or -1 if a stream is not // needed. `stream_assignment` is the existing stream assignment for all // instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that @@ -62,7 +68,7 @@ int ComputeStreamToAssign( if (hlo.opcode() == HloOpcode::kParameter || hlo.opcode() == HloOpcode::kConstant) { // kParameter and kConstant do not need a thunk. - return -1; + return kInvalidStreamNum; } if (hlo.GetModule() @@ -75,17 +81,17 @@ int ComputeStreamToAssign( if (!ImplementedAsGemm(hlo)) { // If `hlo` is not implemented as a GEMM, keep it close to its operands to // avoid excessive synchronization. - int stream_no = -1; + int stream_num = -1; for (const auto* operand : hlo.operands()) { if (stream_assignment.HasStreamAssigned(*operand)) { - stream_no = - std::max(stream_no, stream_assignment.StreamNumberForHlo(*operand)); + stream_num = std::max(stream_num, + stream_assignment.StreamNumberForHlo(*operand)); } } - if (stream_no == -1) { - stream_no = 0; + if (!IsStreamNumValid(stream_num)) { + stream_num = 0; } - return stream_no; + return stream_num; } // Assign different streams to concurrent GEMMs. The code below uses a @@ -94,17 +100,17 @@ int ComputeStreamToAssign( // `hlo` a different stream. std::set forbidden_stream_numbers; for (const auto* seen_gemm : seen_gemms) { - int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm); - if (!forbidden_stream_numbers.count(stream_no) && + int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm); + if (!forbidden_stream_numbers.count(stream_num) && CanRunConcurrently(*seen_gemm, hlo, reachability)) { - forbidden_stream_numbers.insert(stream_no); + forbidden_stream_numbers.insert(stream_num); } } - for (int stream_no = 0; stream_no < stream_assignment.StreamCount(); - ++stream_no) { - if (!forbidden_stream_numbers.count(stream_no)) { - return stream_no; + for (int stream_num = 0; stream_num < stream_assignment.StreamCount(); + ++stream_num) { + if (!forbidden_stream_numbers.count(stream_num)) { + return stream_num; } } return stream_assignment.StreamCount(); @@ -118,11 +124,27 @@ std::unique_ptr AssignStreams(const HloModule& module) { std::unique_ptr reachability = computation.ComputeReachability(); std::vector seen_gemms; + // The execution of different RNG Hlo instructions in the same module updates + // a common global variable. To avoid a race condition, we simply assign all + // RNG kernels to the same stream to make them run sequentially. + // + // TODO(b/111791052): If we remove such a common variable, we will need to + // clean up the code here. + int stream_num_for_rng = kInvalidStreamNum; for (const auto* hlo : computation.MakeInstructionPostOrder()) { - int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment, - *reachability, seen_gemms); - if (stream_no != -1) { - stream_assignment->AssignStreamToHlo(hlo, stream_no); + // If we ever enable fusion of RNG instructions, we will need to extend this + // code to look inside a fused instruction. + int stream_num = (hlo->opcode() == HloOpcode::kRng && + IsStreamNumValid(stream_num_for_rng)) + ? stream_num_for_rng + : ComputeStreamToAssign(*hlo, *stream_assignment, + *reachability, seen_gemms); + if (IsStreamNumValid(stream_num)) { + stream_assignment->AssignStreamToHlo(hlo, stream_num); + if (hlo->opcode() == HloOpcode::kRng && + !IsStreamNumValid(stream_num_for_rng)) { + stream_num_for_rng = stream_num; + } } if (ImplementedAsGemm(*hlo)) { seen_gemms.push_back(hlo); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8..05b305ea4cdfdbaeb42544b626a6b9990bb42f57 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -20,10 +20,17 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::DataLayoutString; -using stream_executor::dnn::FilterLayout; -using stream_executor::dnn::FilterLayoutString; +using se::dnn::DataLayout; +using se::dnn::DataLayoutString; +using se::dnn::FilterLayout; +using se::dnn::FilterLayoutString; + +bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 39a6a38d001f502b2abb8de6efe2ce623b478c71..1fc46bafa10e7ba6c896f081d5c836bd400886c9 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -25,18 +26,20 @@ limitations under the License. namespace xla { namespace gpu { +// Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU. +bool IsVoltaOrLater(const se::StreamExecutor& stream_exec); + // Returns (input, filter, output) XLA Layout protos given the StreamExecutor // layouts. StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, - stream_executor::dnn::DataLayout input, - stream_executor::dnn::FilterLayout filter, - stream_executor::dnn::DataLayout output); + se::dnn::DataLayout input, + se::dnn::FilterLayout filter, + se::dnn::DataLayout output); // Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. -StatusOr> +StatusOr< + std::tuple> XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, const Layout& input, const Layout& filter, const Layout& output); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 686c3c16c97151d5ac06983f8709f1d367eb596c..4fad3f46cf953945e4f395e751e5ba76db97ecc4 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -111,8 +111,8 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index ba5cd2d84dfc0cd1515875e8510c18d89e4ec5f7..9072b30317d253fd6d50e9d98949cad4eaebfe7b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index a10e40451c1db01ce73db7b56a3a0599769fa49b..8579b1545fd24f80621ac0f53b997e33586cbabe 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -24,24 +24,32 @@ namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - std::vector tuple_element_buffer_addresses; - for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { - tuple_element_buffer_addresses.push_back( - buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque()); + auto size = tuple_element_buffers_.size(); + auto tuple_element_buffer_addresses = MakeUnique(size); + for (int i = 0; i != size; ++i) { + tuple_element_buffer_addresses[i] = + buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); } se::DeviceMemory dest_buffer_address( buffer_allocations.GetDeviceAddress(dest_buffer_)); - auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*); + auto host_size = size * sizeof(void*); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream ->ThenMemcpy(&dest_buffer_address, - tuple_element_buffer_addresses.data(), host_size) + tuple_element_buffer_addresses.get(), host_size) .ok()) { return InternalError( "Unable to launch MemcpyH2D from %p to %p with size %lu", - tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), - sizeof(void*) * tuple_element_buffer_addresses.size()); + tuple_element_buffer_addresses.get(), dest_buffer_address.opaque(), + host_size); + } + // Free the tuple address buffer when memcpy is done. + auto* buffers_raw = tuple_element_buffer_addresses.release(); + if (!stream->ThenDoHostCallback([buffers_raw] { delete[] buffers_raw; }) + .ok()) { + delete[] buffers_raw; + return InternalError("Unable to enqueue host callback!"); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 1315a4183a98d6ea9ed4c82d4c22e77c2109ec83..d81d87e7dc54cd752000b85f3ec173d66d7195e4 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -57,6 +57,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, while (true) { // Invoke thunk sequence for while 'condition' computation. profiler->StartHloComputation(); + VLOG(3) << "Executing condition computation"; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream( buffer_allocations, stream, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_condition()); @@ -64,6 +65,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // Copy the result of condition computation and break the loop if 'false'. bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); + VLOG(3) << "condition_result = " << condition_result; Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError( @@ -78,6 +80,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // We measure the time of one execution of the while body computation. The // while body may be executed more than once, the last measurement "wins". profiler->StartHloComputation(); + VLOG(3) << "Executing body computation"; // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc deleted file mode 100644 index c5321df6c466fcb3816fb2aedad65b7c3811cb37..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - -#include -#include - -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { -namespace gpu { - -namespace { - -// TODO(b/33483676) Use an expression tree to specify computations to pattern -// match for while transformations. - -// ExprTree is a simple recursive data structure used to express computation -// patterns to match. -// -// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and -// HloOpcode of the operand. -// -// For example, the following computation: -// -// Parameter -// | -// Const GetTupleElement -// \ / -// Add (root) -// -// Can be matched with the following expression tree: -// -// ExprTree add(HloOpcode::kAdd, -// ExprTree(HloOpcode::kConstant), -// ExprTree(HloOpcode::kGetTupleElement, -// tuple_index, ExprTree(HloOpcode::kParameter))); -// -// Match the ExprTree root against an Hlo graph: -// -// ExprTree::TaggedInstructionMap tagged_instructions; -// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), -// &tagged_instructions)); -// -// Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further processing (i.e. parsing -// constants or recording the tuple_index). -// -class ExprTree { - public: - explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} - ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} - ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) - : opcode_(opcode) { - SetOperand(index0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, - int64 index1, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(index0, operand0); - SetOperand(index1, operand1); - } - ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) - : opcode_(opcode), tag_(tag) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(0, operand0); - SetOperand(1, operand1); - } - - ExprTree(const ExprTree& to_copy) { - opcode_ = to_copy.opcode_; - tag_ = to_copy.tag_; - if (to_copy.fused_root_tree_ != nullptr) { - fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); - } - for (auto& pair : to_copy.operands_) { - CHECK(operands_.find(pair.first) == operands_.end()); - operands_.insert(std::make_pair( - pair.first, std::unique_ptr(new ExprTree(*pair.second)))); - } - } - - void SetFusedRoot(const ExprTree& fused_root) { - fused_root_tree_.reset(new ExprTree(fused_root)); - } - - typedef std::unordered_map - TaggedInstructionMap; - - // Matches 'instruction' HloOpcode against 'opcode_'. - // Recursively matches each operand in 'operands_'. - // Recursively matches fused instructions starting at 'fused_root_tree_' - // if 'opcode_ == kFusion'. - // Returns OK status, and instructions in 'tagged_instructions' for each - // matched ExprTree node with a non-empty 'tag_'. - // Returns error message on failure. - Status Match(const HloInstruction* instruction, - TaggedInstructionMap* tagged_instructions) const { - if (opcode_ != instruction->opcode()) { - return InvalidArgument("got opcode %s, want %s", - HloOpcodeString(instruction->opcode()).c_str(), - HloOpcodeString(opcode_).c_str()); - } - - VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; - if (!tag_.empty()) { - tagged_instructions->insert({tag_, instruction}); - } - - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK(fused_root_tree_ != nullptr); - // Match fused instructions for this node starting a 'fused_root_tree'. - TF_RETURN_IF_ERROR(fused_root_tree_->Match( - instruction->fused_expression_root(), tagged_instructions)); - } - - // Match each operand in 'operands_'. - for (auto& pair : operands_) { - TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), - tagged_instructions)); - } - return Status::OK(); - } - - private: - void SetOperand(int64 index, const ExprTree& operand) { - CHECK_EQ(0, operands_.count(index)); - operands_.insert(std::make_pair(index, MakeUnique(operand))); - } - - HloOpcode opcode_; - std::unordered_map> operands_; - std::unique_ptr fused_root_tree_; - string tag_; -}; - -// MatcherBase is a base class that provides common functionality for -// sub-classes which match specific target sub-computations (i.e. loop -// induction variable initialization, comparison and update). -class MatcherBase { - public: - MatcherBase() {} - virtual ~MatcherBase() {} - - // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first successful match, error status otherwise. - virtual Status Run() { - Status status; - for (const ExprTree& expr_tree : expr_trees_) { - status = MatchExprTree(expr_tree); - if (status.ok()) { - return status; - } - } - return status; - } - - virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; - - // Returns the constant value parsed form kConstant 'instruction'. - // Returns error status otherwise. - Status ParseConstInteger(const HloInstruction* instruction, - int64* const_value) const { - CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); - PrimitiveType element_type = instruction->shape().element_type(); - if (element_type != S32 && element_type != S64) { - return InvalidArgument("Expected constant of integral type."); - } - const Literal& literal = instruction->literal(); - PrimitiveType type = literal.shape().element_type(); - if (type != S32 && type != S64) { - return InvalidArgument("Must use S32 or S64 integral types."); - } - if (type == S32) { - *const_value = static_cast(literal.GetFirstElement()); - } else if (type == S64) { - *const_value = literal.GetFirstElement(); - } - return Status::OK(); - } - - StatusOr GetTaggedInstruction( - const string& tag, - const ExprTree::TaggedInstructionMap& tagged_instructions) { - auto it = tagged_instructions.find(tag); - if (it == tagged_instructions.end()) { - return InvalidArgument("Cound not find instruction for tag: %s", - tag.c_str()); - } - return it->second; - } - - protected: - std::vector expr_trees_; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); -}; - -// WhileConditionComputationMatcher attempts to match a target computation -// pattern in the while condition sub-computation. -// If the target pattern is matched, two pieces of information are extracted -// from 'tagged' instructions returned by the matcher: -// -// *) 'tuple_index': -// *) The loop induction variable tuple_index from the GetTupleElement -// instruction of the matched computation. -// *) Used in subsequent matching passes of while init operand and body -// computations to select loop induction variable tuple element. -// -// *) 'loop_limit': -// *) The integral value from Constant root operand in matched computation. -// *) Used as the constant for the loop limit. -// -class WhileConditionComputationMatcher : public MatcherBase { - public: - explicit WhileConditionComputationMatcher(const HloComputation* computation) - : computation_(computation) { - expr_trees_.emplace_back(BuildCondExprTree()); - } - - int64 loop_limit() const { return loop_limit_; } - int64 tuple_index() const { return tuple_index_; } - - private: - // Builds expression tree for the following condition computation: - // - // Const Parameter - // \ / - // Fusion ------------> FusionParam FusionParam - // \ / - // GTE / - // \ / - // LessThan (fused root) - // - ExprTree BuildCondExprTree() { - // Build ExprTree for fused instructions. - ExprTree fused_root( - HloOpcode::kLt, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), - ExprTree(HloOpcode::kParameter)); - - // Build top-level computation. - ExprTree root(HloOpcode::kFusion, - ExprTree(HloOpcode::kConstant, "loop_limit"), - ExprTree(HloOpcode::kParameter, "param0")); - - root.SetFusedRoot(fused_root); - return root; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while condition"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - // Get tagged GTE instruction and set 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - tuple_index_ = gte->tuple_index(); - - // Get tagged Constant instruction and parse 'loop_limit_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_limit", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); - - // Get tagged "param0" instruction, and check that it matches - // 'computation_' parameter 0. - TF_ASSIGN_OR_RETURN(const HloInstruction* param0, - GetTaggedInstruction("param0", tagged_instructions)); - if (param0 != computation_->parameter_instruction(0)) { - return InvalidArgument("Unexpected Parameter0 instruction : %s", - param0->name().c_str()); - } - - // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, - // and compare it to 'computation_' parameter0. - TF_ASSIGN_OR_RETURN( - const HloInstruction* gte_fusion_param0, - GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); - CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); - CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->parent()->FusionInstruction()->operand( - gte_fusion_param0->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - gte_fusion_param0->name().c_str()); - } - - return Status::OK(); - } - - const HloComputation* computation_; - - int64 loop_limit_ = -1; - int64 tuple_index_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); -}; - -// WhileInitOperandMatcher matches a target computation pattern of the -// while instructions 'init' operand, indexing the tuple at 'tuple_index'. -// On success, parses constant 'loop_start' which represents the loop induction -// variable start values, then returns OK. -// Returns error status otherwise. -class WhileInitOperandMatcher : public MatcherBase { - public: - WhileInitOperandMatcher(const HloInstruction* while_hlo, - const int64 tuple_index) - : while_hlo_(while_hlo), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildInitExprTree()); - } - - int64 loop_start() const { return loop_start_; } - - private: - // Builds expression tree for the following while init operand subcomputation: - // - // Const - // | - // Copy - // | - // Tuple0 - // | - // While - // - ExprTree BuildInitExprTree() { - return ExprTree( - HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kConstant, "loop_start")))); - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while init"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); - - // Get tagged while instruction check against 'while_hlo_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, - GetTaggedInstruction("while", tagged_instructions)); - if (while_hlo != while_hlo_) { - return InvalidArgument("Expected While for instruction : %s", - while_hlo->name().c_str()); - } - - // Get tagged Constant instruction and parse 'loop_start_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_start", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - - return Status::OK(); - } - - const HloInstruction* while_hlo_; - const int64 tuple_index_; - - int64 loop_start_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); -}; - -// WhileBodyComputationMatcher matches a target computation pattern for -// the loop induction variable update. Matching proceeds from the while body -// computation root[tuple_index] to param[tuple_index], where 'tuple_index' -// If the target pattern is matched, parses a constant which represents the -// loop induction variable increment value, then returns status OK. -// Returns error status otherwise. -class WhileBodyComputationMatcher : public MatcherBase { - public: - WhileBodyComputationMatcher(const HloComputation* computation, - const int64 tuple_index) - : computation_(computation), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); - expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); - } - - int64 loop_increment() const { return loop_increment_; } - - private: - // Builds expression tree for the following while body computation: - // - // - // FusionParam FusionParam - // \ / - // Const Param \ GTE1 - // \ / \ / - // Fusion -----------> Add - // | - // Copy - // | - // Tuple0 - // - ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { - // Build ExprTree for fused instructions. - ExprTree gte1 = - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); - ExprTree fused_root(HloOpcode::kAdd, const_index, - ExprTree(HloOpcode::kParameter), gte_index, gte1); - - // Build fusion instruction (and set fused root). - ExprTree fusion(HloOpcode::kFusion, 0, - ExprTree(HloOpcode::kConstant, "loop_increment"), 1, - ExprTree(HloOpcode::kParameter, "param0")); - fusion.SetFusedRoot(fused_root); - - // Build top-level computation. - ExprTree tuple0(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, fusion)); - return tuple0; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while body"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - for (const auto& pair : tagged_instructions) { - const auto& tag = pair.first; - const auto& inst = pair.second; - - if (tag == "gte" && inst->tuple_index() != tuple_index_) { - // Check that the matched GTE instruction is at the 'tuple_index' we - // matched in the while condition computation. - return InvalidArgument("Unexpected tuple index instruction : %s", - inst->name().c_str()); - } else if (tag == "loop_increment") { - // ParseHloString the constant which represents the loop induction - // variable increment value. - TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); - } else if (tag == "param0" && - inst != computation_->parameter_instruction(0)) { - // Check that the matched parameter == parameter 0 from 'computation_'. - return InvalidArgument("Unexpected Parameter0 instruction : %s", - inst->name().c_str()); - } else if (tag == "gte.fusion_param.param0") { - // Fusion parameter: lookup and compare with associated fusion operand. - CHECK_EQ(HloOpcode::kParameter, inst->opcode()); - CHECK(inst->IsFused()); - if (inst->parent()->FusionInstruction()->operand( - inst->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - inst->name().c_str()); - } - } - } - return Status::OK(); - } - - const HloComputation* computation_; - const int64 tuple_index_; - - int64 loop_increment_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); -}; - -} // namespace - -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo) { - if (while_hlo->opcode() != HloOpcode::kWhile) { - return InvalidArgument("Expected While instruction."); - } - - WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); - TF_RETURN_IF_ERROR(cond_matcher.Run()); - - WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(init_matcher.Run()); - - WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), - cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(body_matcher.Run()); - - // Check for valid For loop parameters. - if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { - return InvalidArgument("Loop start must be less than loop limit."); - } - if (body_matcher.loop_increment() <= 0) { - return InvalidArgument("Loop increment must greater than zero."); - } - return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), - body_matcher.loop_increment()); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h deleted file mode 100644 index fe3a954e1828ee4a323872eea81f64c7e780ad24..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Runs an analysis of the while loop instruction 'while_hlo' (and its -// associated sub-computations) to determine if it can be transformed into an -// equivalent "for" loop with the following "for" loop parameters: -// -// *) 'loop_start': loop induction variable starting value. -// *) 'loop_limit': loop induction variable limit value. -// *) 'loop_increment': loop induction variable per-iteration increment value. -// -// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success. -// The values in the returned tuple are values extracted from the 'while_hlo' -// operand (and its sub-computations) during analysis. -// Returns an error status on failure. -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index dbc8442ed2785a112b674632689256c01282156b..c5f3906356d821e059d2b1213c9083c4408a4d1c 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -110,12 +109,12 @@ class WhileTransformerTest : public HloTestBase { void RunFusionPasses() { // Run standard fusion passes. - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .status()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .status()); } void RunCopyInsertionPass() { @@ -141,10 +140,7 @@ class WhileTransformerTest : public HloTestBase { Shape condition_result_shape_; }; -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -153,18 +149,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -173,19 +164,14 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { - // Build computation with invalid loop limit. +TEST_F(WhileTransformerTest, ImpossibleLoopLimit) { + // Build computation with an impossible loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); @@ -193,17 +179,13 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop start must be less than loop limit.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(0, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -212,11 +194,9 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop increment must greater than zero.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_FALSE(result); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 63a8a813cddf304e60fa9b4bbf709eca2d7c2cae..be9098f555e78f3cabfe55481356f8b6841a3a2b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -151,8 +151,11 @@ message HloInstructionProto { // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; - // Cross Replica Sum fields. + // Cross replica op fields. + // TODO(b/112107579): remove replica_group_ids field and always use + // replica_groups. repeated int64 replica_group_ids = 44; + repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; @@ -160,6 +163,8 @@ message HloInstructionProto { // present for Send and Recv instructions and their SendDone and RecvDone // partners. bool is_host_transfer = 47; + + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 1f672502f72f9c658b681383e858995f6e94d2c7..1bbb0ff08e26f626f4c3992a5f20ec4990f7db2d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -49,9 +49,9 @@ Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) { // The default number of bytes accessed for an instruction is the sum of the // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not // handle opaque types. - float bytes_accessed = shape_size_(hlo->shape()); + float bytes_accessed = GetShapeSize(hlo->shape()); for (const HloInstruction* operand : hlo->operands()) { - bytes_accessed += shape_size_(operand->shape()); + bytes_accessed += GetShapeSize(operand->shape()); } current_properties_[kBytesAccessedKey] = bytes_accessed; @@ -121,6 +121,13 @@ Status HloCostAnalysis::HandleElementwiseOp( } } +int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const { + if (!LayoutUtil::HasLayout(shape)) { + return 0; + } + return shape_size_(shape); +} + Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } @@ -181,21 +188,21 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { } Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { - current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; + current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicSlice( const HloInstruction* dynamic_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_slice->shape()) * 2; + GetShapeSize(dynamic_slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicUpdateSlice( const HloInstruction* dynamic_update_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; + GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -204,7 +211,7 @@ Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) { // through them). The memory touched is then only the size of the output // index table of the tuple. - current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape()); + current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape()); return Status::OK(); } @@ -526,12 +533,25 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. double flops = 0.0; - ShapeUtil::ForEachSubshape( - crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { - flops += ShapeUtil::ElementsIn(subshape); - } - }); + ShapeUtil::ForEachSubshape(crs->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { + // TODO(b/110096724): Compute correct cost here. + double flops = 0.0; + ShapeUtil::ForEachSubshape(hlo->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); current_properties_[kFlopsKey] = flops; return Status::OK(); } @@ -546,15 +566,9 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { } Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { - // Compute the properties of the fused expression and attribute them to the - // fusion node. Use a dummy shape_size to avoid any errors from trying to - // calculate the size of a shape that does not have a layout, since nodes - // inside fusion nodes do not necessarily have a layout assigned. - ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; }; TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation(), - &shape_size)); + ProcessSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. // Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -563,11 +577,11 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { ShapeUtil::ForEachSubshape( fusion->shape(), [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { - current_properties_[kBytesAccessedKey] += shape_size_(subshape); + current_properties_[kBytesAccessedKey] += GetShapeSize(subshape); }); for (const HloInstruction* operand : fusion->operands()) { - current_properties_[kBytesAccessedKey] += shape_size_(operand->shape()); + current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape()); } return Status::OK(); @@ -648,6 +662,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { return Status::OK(); } +Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { + // TODO(b/32945756): Compute the properties of the sub-computation. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -685,11 +704,8 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { } StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation, const ShapeSizeFunction* shape_size) { - if (shape_size == nullptr) { - shape_size = &shape_size_; - } - HloCostAnalysis visitor(*shape_size, per_second_rates_); + HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.properties(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 82d650dc7b2a7fdd7c156d5fadcabd40f5535161..193a04bea0831de2b3aca19b17a445ad73e02e49 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleHostCompute(const HloInstruction* host_compute) override; @@ -104,6 +105,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; + Status HandleScatter(const HloInstruction* scatter) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -149,11 +151,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { const Properties& per_second_rates); // Returns the properties computed from visiting the computation rooted at the - // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null, - // otherwise uses shape_size_. - StatusOr ProcessSubcomputation( - HloComputation* computation, - const ShapeSizeFunction* shape_size = nullptr); + // given hlo. + StatusOr ProcessSubcomputation(HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); @@ -170,6 +169,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor { static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, const HloToProperties& hlo_to_properties); + // Decorates shape_size_ by returning 0 immediately if the shape does not have + // a layout. + int64 GetShapeSize(const Shape& shape) const; + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 9fd0363f578f562aa30bd3dadfea2c251a722af8..2c854eea18642eb7cb081b4fdfe3bc83627e41ae 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index de1a32d8bd9217baabda4ab4b02bf28baebad531..bbfb0c253f583b633c4b2c34b2f068b563d3d9e0 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -1017,19 +1017,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kFusion) { + if (fusion_can_share_buffer_ != nullptr) { + return fusion_can_share_buffer_(user, operand); + } // Get the parameter associated with 'operand'; HloInstruction* fusion_param = user->fused_parameter(user->operand_index(operand)); const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - if (MultiDynamicSliceUseShareSameIndices(value.uses())) { - return true; - } - return false; + if (MultiDynamicSliceUseShareSameIndices(value.uses())) { + return true; } - const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == @@ -1039,13 +1037,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root at operand // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else { - return AreTransitiveUsesElementwiseOrTuple(fusion_param); + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } + return false; } - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + return AreTransitiveUsesElementwiseOrTuple(fusion_param); + } + if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. // Check if one operand of kAdd fused root is kDot or kConvolution. @@ -1066,11 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root (at operand // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } else if (fusion_can_share_buffer_ != nullptr && - fusion_can_share_buffer_(user, operand)) { - return true; + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + return false; } } @@ -1081,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // Get all uses of value defined by 'operand' at 'operand_index'. const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 37bc2d2c9d2a0d0624917337b36c5d5f625c0991..4755c4a0cf8d268b1c47e596a14605eb2c60b36c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2232,6 +2232,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); @@ -2323,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -2332,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index c804f4364f6d16d5b8112219ce884495200aa827..b9244b8e9e5f34e7ac5113c8eacb6f8243eea314 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -144,6 +144,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || + opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kConditional) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 5f575b24a1fb36c5384592028e0f1f6a8e9404b6..251109c89f1ae971cec95057d604c396d4f99522 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -2048,6 +2048,459 @@ ENTRY main { *Evaluate({operand.get(), gather_indices.get()}))); } +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({2, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), + ErrorSpec{0.1, 0.01})); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // + {{-40, 40}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, + EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10}}, {{20}}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *operand, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR1({10, 61, 32}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. TEST_P(HloEvaluatorTest, DoesCompareBF16) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index f5e477e1151ec913bf4b7c06ae6a54e797bfb6a2..4dc03fd06d6917ed74371aa9926a1059e502826d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1481,8 +1481,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Rank(arg->shape()) - dimensions.size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), + {&arg->shape(), &init_value->shape()}, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) @@ -1772,6 +1771,388 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + // Reshapes the scatter indices input to have a trailing degenerate `1` + // dimension if necessary. Hands over the ownership of the newly created + // literal (if there is one) to `reshaped_indices`. + StatusOr> ReshapedScatterIndices( + int64 index_vector_dim, const Literal& indices, + std::unique_ptr* reshaped_indices) { + if (indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(indices); + } + + std::vector new_shape(indices.shape().dimensions().begin(), + indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); + return std::cref(**reshaped_indices); + } + + // Returns an ShapeUtil::IndexIterationSpace that iterates over the update + // scatter dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_scatter_dim = + !c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_scatter_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // Return an ShapeUtil::IndexIterationSpace that iterates over the update + // window dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_window_dim = + c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // This functor computes the contribution of scatter_indices to an input index + // corresponding to an update index. That is, given an update index I, it + // picks out the scatter indices in I and uses them to look up a scatter + // index, S, from the scatter indices tensor, and expands S into the input + // space according to scatter_dims_to_operand_dims. + // + // This is similar to the class HloEvaluator::OutputGatherIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateScatterIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateScatterIndexToInputIndex( + const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& updates_shape, const Literal* scatter_indices) + : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + update_dim_is_scatter_dims_.push_back( + !c_binary_search(dim_numbers_.update_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i); + if (index_of_input_dim_in_index_vector == + dim_numbers_.scatter_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(scatter_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of scatter_indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // scatter input index, but: + // + // - Instead of allocating memory to represent the scatter input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return tensorflow::gtl::ArraySlice(input_index_); + } + + private: + // Propagates the scatter index dimensions from the update index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the + // dimension we iterate over in FetchIndexVector. + void PropagateUpdateIndexScatterDimsToIndexVectorIndex( + tensorflow::gtl::ArraySlice update_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = update_index.size(); i < e; i++) { + if (!update_dim_is_scatter_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = update_index[i]; + } + } + + // Populates index_vector_ by iterating over scatter_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // update_dim_is_scatter_dims_[i] is true iff the update index i is a + // scatter dimension. + std::vector update_dim_is_scatter_dims_; + + // The buffer into which we construct an index into scatter_indices_ to + // fetch the index vector. + std::vector index_vector_index_; + + // The index vector fetched from scatter_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + + const ScatterDimensionNumbers& dim_numbers_; + const Literal& scatter_indices_; + }; + + // This functor computes the contribution of the window indices in an update + // index to an input index. That is, given an update index I it picks out the + // update window indices in I and expands it into a window index into the + // input shape. + // + // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateWindowIndexToInputIndex( + const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& updates_shape) { + std::vector window_index_to_update_index; + int64 update_index_count = 0; + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.update_window_dims(), i)) { + window_index_to_update_index.push_back(update_index_count++); + } else { + update_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_dim_value_to_update_index_.push_back(-1); + } else { + input_dim_value_to_update_index_.push_back( + window_index_to_update_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // window input index, but instead of allocating memory to represent the + // scatter input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexWindowDimsToInputIndex(update_index); + return tensorflow::gtl::ArraySlice(input_index_); + } + + // Returns for a given 'input_dim' the corresponding update dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_update_index(int64 input_dim) { + return input_dim_value_to_update_index_[input_dim]; + } + + private: + // Propagates window dimensions from the update index to input_index_ by + // mutating input_index_ in place. + void PropagateUpdateIndexWindowDimsToInputIndex( + tensorflow::gtl::ArraySlice update_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_update_index_[i] != -1) { + input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the update index. See + // PropagateUpdateIndexWindowDimsToInputIndex. + std::vector input_dim_value_to_update_index_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + }; + + Status HandleScatter(HloInstruction* scatter) override { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + const Literal& operand = + parent_->GetEvaluatedLiteralFor(scatter->operand(0)); + std::unique_ptr reshaped_scatter_indices; + TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, + ReshapedScatterIndices(dim_numbers.index_vector_dim(), + parent_->GetEvaluatedLiteralFor( + scatter->operand(1)), + &reshaped_scatter_indices)); + const Literal& updates = + parent_->GetEvaluatedLiteralFor(scatter->operand(2)); + const Shape& updates_shape = updates.shape(); + const Shape& operand_shape = operand.shape(); + + ShapeUtil::IndexIterationSpace scatter_indices_iteration_space = + IterationSpaceForUpdateScatterIndices(updates_shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForUpdateWindowIndices(updates_shape, dim_numbers); + + std::vector input_index(operand_shape.dimensions_size()); + std::vector update_index(updates_shape.dimensions_size()); + std::vector input_scatter_index_clamped( + operand_shape.dimensions_size()); + + UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( + &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape, &scatter_indices); + UpdateWindowIndexToInputIndex update_window_index_to_input_index( + scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape); + + // Initialize the result with the operand. This makes it easier to handle + // the updates even when the indices are repeated. + std::unique_ptr result = operand.CloneToUnique(); + HloEvaluator embedded_evaluator; + auto scatter_inner_loop_body = + [&](tensorflow::gtl::ArraySlice update_window_index, + tensorflow::gtl::ArraySlice input_scatter_index, + tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_window_index, + update_window_index_to_input_index(update_window_index)); + for (int i = 0, e = update_index.size(); i < e; i++) { + update_index[i] = update_scatter_index[i] + update_window_index[i]; + DCHECK_LT(update_index[i], updates_shape.dimensions(i)); + } + for (int i = 0, e = input_scatter_index.size(); i < e; i++) { + int64 update_dim = + update_window_index_to_input_index.input_dim_value_to_update_index( + i); + // If 'update_dim' is -1, it means 'i' is an elided window dim. This + // means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the update dimension size to + // be 1. + int64 update_dim_size = + update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); + // Clamp the scatter index so that the scatter region fits in the + // operand. input_scatter_index_clamped[i] = + // clamp(input_scatter_index[i], 0, + // operand_shape.dimensions(i) - + // update_dim_size); + input_scatter_index_clamped[i] = + std::min(operand_shape.dimensions(i) - update_dim_size, + std::max(0LL, input_scatter_index[i])); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + input_index[i] = input_scatter_index_clamped[i] + input_window_index[i]; + DCHECK_GE(input_index[i], 0); + DCHECK_LT(input_index[i], operand_shape.dimensions(i)); + } + + auto result_value_literal = + LiteralUtil::CreateR0(result->Get(input_index)); + auto update_value_literal = + LiteralUtil::CreateR0(updates.Get(update_index)); + std::unique_ptr updated_result = + embedded_evaluator + .Evaluate( + *scatter->to_apply(), + {result_value_literal.get(), update_value_literal.get()}) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on the + // same computation. + embedded_evaluator.ResetVisitStates(); + result->Set(input_index, updated_result->Get({})); + return true; + }; + + auto scatter_outer_loop_body = + [&](tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_scatter_index, + update_scatter_index_to_input_index(update_scatter_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, window_indices_iteration_space, + [&](tensorflow::gtl::ArraySlice update_window_index) { + return scatter_inner_loop_body( + update_window_index, input_scatter_index, update_scatter_index); + })); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, scatter_indices_iteration_space, + scatter_outer_loop_body)); + parent_->evaluated_[scatter] = std::move(result); + return Status::OK(); + } + Status HandleSlice(HloInstruction* slice) override { auto operand = slice->operand(0); const Shape& shape = slice->shape(); @@ -2078,10 +2459,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the start indices so the slice is in-bounds w.r.t the operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to officially document different behavior. for (int64 i = 0; i < start.size(); ++i) { start[i] = std::min( std::max(int64{0}, start[i]), @@ -2115,10 +2492,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index fd5085bed234068a1bdf18977b38d92badc02a49..1efa6eb5bda7e1cb90874e0466aafd2c788a3fbf 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -844,7 +844,10 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8) { + // Allow HloDotDumper to print HloInstruction reconstructed from HloProto + // collected from profiling tools. Those constants may not have a valid + // literal. + if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -1019,6 +1022,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kScatter: + // Do not de-emphasize Scatter, since it involves significant work. case HloOpcode::kCopy: // Emphasize copy nodes, which are either physical transposes (and thus // significant), or copies of read-only buffers (and thus dead weight). @@ -1043,6 +1048,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: return kGray; case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8b9bdd2f46fe8a63b419b45ef2c2a2e025c60c8f..8690f2cdaa9b45d126e91b123c6992cbe2f27e1d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -320,6 +320,15 @@ StatusOr> HloInstruction::CreateFromProto( /*all_reduce_id=*/all_reduce_id); break; } + case HloOpcode::kAllToAll: { + instruction = CreateAllToAll( + proto.shape(), all_operands(), + /*replica_groups=*/ + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), + /*barrier=*/proto.cross_replica_sum_barrier()); + break; + } case HloOpcode::kConvolution: TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " @@ -404,6 +413,22 @@ StatusOr> HloInstruction::CreateFromProto( *gather_dimension_numbers, gather_window_bounds); break; } + case HloOpcode::kScatter: { + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "Scatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_scatter_dimension_numbers()) + << "Scatter instruction should have ScatterDimensionNumbers set."; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Scatter instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + auto scatter_dimension_numbers = MakeUnique( + proto.scatter_dimension_numbers()); + instruction = + CreateScatter(proto.shape(), operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -655,6 +680,14 @@ HloInstruction::CreateCrossReplicaSum( all_reduce_id); } +/* static */ std::unique_ptr HloInstruction::CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) { + return MakeUnique(shape, operands, replica_groups, + barrier); +} + /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -812,11 +845,25 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateReduce( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) { + auto instruction = WrapUnique(new HloReduceInstruction( + shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); + return std::move(instruction); +} + +/* static */ std::unique_ptr HloInstruction::CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - return MakeUnique( - shape, arg, init_value, dimensions_to_reduce, reduce_computation); + std::vector all_args; + all_args.reserve(operands.size() * 2); + all_args.insert(all_args.end(), operands.begin(), operands.end()); + all_args.insert(all_args.end(), init_values.begin(), init_values.end()); + return MakeUnique(shape, all_args, dimensions_to_reduce, + reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( @@ -1062,6 +1109,16 @@ bool HloInstruction::HasSideEffect() const { gather_dim_numbers, window_bounds); } +/* static */ std::unique_ptr HloInstruction::CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) { + return MakeUnique(shape, operand, scatter_indices, + updates, update_computation, + scatter_dim_numbers); +} + /* static */ std::unique_ptr HloInstruction::CreateDomain( const Shape& shape, HloInstruction* operand, std::unique_ptr operand_side_metadata, @@ -1113,6 +1170,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1124,6 +1182,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDynamicSlice: case HloOpcode::kSort: case HloOpcode::kGather: + case HloOpcode::kScatter: case HloOpcode::kIota: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; @@ -1579,6 +1638,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -1587,6 +1647,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: + case HloOpcode::kScatter: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1693,6 +1754,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1711,6 +1773,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -1977,7 +2040,8 @@ std::vector HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == HloOpcode::kCrossReplicaSum) { + opcode() == HloOpcode::kCrossReplicaSum || + opcode() == HloOpcode::kScatter) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2013,6 +2077,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2219,6 +2284,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllToAll: + return visitor->HandleAllToAll(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2311,6 +2378,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kScatter: + return visitor->HandleScatter(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); case HloOpcode::kAfterAll: @@ -3091,12 +3160,23 @@ const std::vector& HloInstruction::replica_group_ids() const { return Cast(this)->replica_group_ids(); } +const std::vector& HloInstruction::replica_groups() const { + return Cast(this)->replica_groups(); +} + string HloInstruction::cross_replica_sum_barrier() const { - return Cast(this)->cross_replica_sum_barrier(); + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->cross_replica_sum_barrier(); + } + return Cast(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast(this)->set_cross_replica_sum_barrier( + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->set_cross_replica_sum_barrier( + barrier); + } + return Cast(this)->set_cross_replica_sum_barrier( barrier); } @@ -3171,4 +3251,9 @@ tensorflow::gtl::ArraySlice HloInstruction::gather_window_bounds() return Cast(this)->gather_window_bounds(); } +const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() + const { + return Cast(this)->scatter_dimension_numbers(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 30bff286c20033ec193fb29d8c4c935ce6475a27..3c575ae6ea8e60f48def4debcd9cfbea63e396b2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -447,8 +447,27 @@ class HloInstruction { HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id = - tensorflow::gtl::nullopt); + const tensorflow::gtl::optional& all_reduce_id); + + // This op handles the communication of an Alltoall operation. On each core, + // the operands are N ops in the same shape, where N is the number of cores + // participating the Alltoall. Then the N operands are scattered to N cores, + // e.g., the ith operand is sent to the ith core. Then each core gathers the + // received data into a tuple. + // + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall + // will be applied within subgroups in the specified order. For example, + // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied + // within replica 1, 2, 3, and in the gather phase, the received blocks will + // be concatenated in the order of 1, 2, 3; another Alltoall will be applied + // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. + // + // TODO(b/110096724): This is NOT YET ready to use. + static std::unique_ptr CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -542,17 +561,34 @@ class HloInstruction { int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) - // is applied successively to every element in operand. That is, if f is the - // function to apply (which either takes 2 [accumulator, value] or 3 - // [accumulator, index, value] arguments) and init is a reduction operator - // specified initial value (for example, 0 for addition), then this operation - // will compute: - // f(f(init, [index0], value0), [index1], value1), ...) + // is applied successively to every element in operand. For example, let f be + // the function to apply, which takes 2 arguments, an accumulator and the + // current value. Let init be an initial value (which is normally chosen to be + // the identity element for f, e.g. 0 if f is addition). + // Then the reduce HLO will compute: + // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); + // A more general, multiple-argument version of the above. + // The function to apply, f, now takes N arguments: + // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., + // init_valueN], and returns an N-tuple. The performed computation is (for + // commutative and associative f operators) equivalent to: + // + // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0) + // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, + // ..., inputN.value1) + // ... + // TODO(b/112040122): Add support to this in HLO passes and in backends. + static std::unique_ptr CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Creates a reduce-window instruction, where the computation (given // by the handle) is applied window-wise at each valid window // position in the operand. @@ -645,6 +681,12 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + static std::unique_ptr CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + // Creates a kDomain instruction which delimits an HLO domain which have // the provided user and operand side metadata. static std::unique_ptr CreateDomain( @@ -1015,9 +1057,7 @@ class HloInstruction { if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - auto device = sharding_->UniqueDevice(); - return device.ok() ? device.ValueOrDie() - : tensorflow::gtl::optional(); + return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. @@ -1394,6 +1434,9 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::replica_group_ids. const std::vector& replica_group_ids() const; + // Delegates to HloAllToAllInstruction::replica_groups. + const std::vector& replica_groups() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); @@ -1455,6 +1498,9 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_window_bounds. tensorflow::gtl::ArraySlice gather_window_bounds() const; + // Delegates to HloScatterInstruction::scatter_dimension_numbers(). + const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Old methods kept for smooth subclassing transition END. protected: diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index b75a2bd34bc5d3b5b6100515748df787b9e7f08a..8a694dde8066ab9a1138b9f7981153d451ddb89e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1425,6 +1425,55 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, StringifyScatter) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape scatter_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape scatter_updates_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Scatter"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* scatter_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, scatter_indices_tensor_shape, "scatter_indices")); + HloInstruction* scatter_updates = + builder.AddInstruction(HloInstruction::CreateParameter( + 2, scatter_updates_shape, "scatter_updates")); + + HloComputation::Builder update_builder("Scatter.update"); + update_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1")); + update_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); + + auto module = CreateNewModule(); + auto* update_computation = + module->AddEmbeddedComputation(update_builder.Build()); + + HloInstruction* scatter_instruction = + builder.AddInstruction(HloInstruction::CreateScatter( + input_tensor_shape, input, scatter_indices, scatter_updates, + update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + scatter_instruction->ToString(), + "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} " + "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, " + "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), " + "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, " + "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, " + "to_apply=%Scatter.update"); +} + TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { // Tests stringification of a simple op, fusion, while, and conditional. const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index df26a2c744fbcac814727139e1cf7f23037dcc50..1de5032670ff47cda5599cf736bbd3529cfcaba9 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -359,6 +359,67 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( cross_replica_sum_barrier(), all_reduce_id()); } +HloAllToAllInstruction::HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) + : HloInstruction(HloOpcode::kAllToAll, shape), + replica_groups_(replica_groups), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +bool HloAllToAllInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier(); +} + +std::unique_ptr +HloAllToAllInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return MakeUnique( + shape, new_operands, replica_groups(), cross_replica_sum_barrier()); +} + +std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result; + std::vector replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", Join(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + + return result; +} + +HloInstructionProto HloAllToAllInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) @@ -438,13 +499,14 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl( } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { - AppendOperand(arg); - AppendOperand(init_value); + for (HloInstruction* arg : args) { + AppendOperand(arg); + } AppendComputation(reduce_computation); } @@ -477,8 +539,8 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( - shape, new_operands[0], new_operands[1], dimensions(), to_apply()); + return MakeUnique(shape, new_operands, dimensions(), + to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, @@ -2015,4 +2077,91 @@ std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( gather_window_bounds()); } +HloScatterInstruction::HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) + : HloInstruction(HloOpcode::kScatter, shape) { + AppendOperand(operand); + AppendOperand(scatter_indices); + AppendOperand(updates); + AppendComputation(update_computation); + scatter_dimension_numbers_ = + MakeUnique(scatter_dim_numbers); +} + +string HloScatterInstruction::ScatterDimensionNumbersToString() const { + string update_window_dims = + StrCat("update_window_dims={", + Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string inserted_window_dims = StrCat( + "inserted_window_dims={", + Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + string scatter_dims_to_operand_dims = StrCat( + "scatter_dims_to_operand_dims={", + Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + + return Join>( + {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ ScatterDimensionNumbers +HloScatterInstruction::MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim) { + ScatterDimensionNumbers scatter_dim_numbers; + for (int64 update_window_dim : update_window_dims) { + scatter_dim_numbers.add_update_window_dims(update_window_dim); + } + for (int64 inserted_window_dim : inserted_window_dims) { + scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim); + } + for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) { + scatter_dim_numbers.add_scatter_dims_to_operand_dims( + scatter_dim_to_operand_dim); + } + scatter_dim_numbers.set_index_vector_dim(index_vector_dim); + return scatter_dim_numbers; +} + +HloInstructionProto HloScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + return proto; +} + +std::vector HloScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {ScatterDimensionNumbersToString()}; +} + +bool HloScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals( + scatter_dimension_numbers(), + casted_other.scatter_dimension_numbers()) && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), + scatter_dimension_numbers()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e4031f04d5c0062d73efb2c8f95b462b691407fa..9586ad667345111d05015e035c93fe6578e3b665 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -224,8 +224,7 @@ class HloAllReduceInstruction : public HloInstruction { HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id = - tensorflow::gtl::nullopt); + const tensorflow::gtl::optional& all_reduce_id); // Returns the group ids of each replica for CrossReplicaSum op. const std::vector& replica_group_ids() const { @@ -274,6 +273,47 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::optional all_reduce_id_; }; +class HloAllToAllInstruction : public HloInstruction { + public: + explicit HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operand, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); + + const std::vector& replica_groups() const { + return replica_groups_; + } + + // TODO(b/110096724): rename this. + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector replica_groups_; + + // The string representation of the barrier config. + string cross_replica_sum_barrier_; +}; + class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, @@ -332,7 +372,7 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: explicit HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. @@ -341,6 +381,18 @@ class HloReduceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the input tensors to be reduced. + tensorflow::gtl::ArraySlice inputs() const { + return tensorflow::gtl::ArraySlice(operands(), 0, + operand_count() / 2); + } + + // Returns the init values of the reduction. + tensorflow::gtl::ArraySlice init_values() const { + return tensorflow::gtl::ArraySlice( + operands(), operand_count() / 2, operand_count()); + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -535,6 +587,8 @@ class HloConstantInstruction : public HloInstruction { explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } + // Returns whether there is literal associated with this instruction. + bool HasLiteral() const { return literal_ != nullptr; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1199,6 +1253,45 @@ class HloGatherInstruction : public HloInstruction { std::vector gather_window_bounds_; }; +class HloScatterInstruction : public HloInstruction { + public: + explicit HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + const ScatterDimensionNumbers& scatter_dimension_numbers() const { + CHECK(scatter_dimension_numbers_ != nullptr); + return *scatter_dimension_numbers_; + } + // Returns the dump string of the scatter dimension numbers. + string ScatterDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of ScatterDimensionNumbers. + static ScatterDimensionNumbers MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::unique_ptr scatter_dimension_numbers_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index f0d9fdbc8f86da0bb9d7f9235239df677c9506bc..71b44507cc704344ff6fe5269ea498bb32cfb8a6 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -299,9 +299,12 @@ TokKind HloLexer::LexNumberOrPattern() { static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strto64( - StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); - return TokKind::kInt; + auto slice = StringPieceFromPointers(token_start_, current_ptr_); + if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + return TokKind::kInt; + } + LOG(ERROR) << "Failed to parse int literal: " << slice; + return TokKind::kError; } static LazyRE2 neg_inf = {"-inf"}; diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 7de59acc1efbc0150b95ebdd85a13ede48eec2f9..7961aece541faeb66875885b380158756c503250 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -157,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) { Array assignment({2}); assignment.SetValues({0, 1}); auto sharding = HloSharding::Tuple( - tuple_shape, - {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), - HloSharding::AssignDevice(1), HloSharding::Replicate()}); + tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1), + HloSharding::Replicate()}); p2->set_sharding(sharding); EXPECT_THAT(p0.get(), op::NoSharding()); @@ -172,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) { EXPECT_THAT( p2.get(), - op::Sharding( - "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}")); EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 59e9a5a94aa4fc6270bde76c19dbd0d4506a563c..ec279867e595b66a22882703cc06046e3e916c96 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,7 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ V(kBatchNormInference, "batch-norm-inference") \ @@ -118,6 +119,7 @@ namespace xla { V(kReverse, "reverse") \ V(kRng, "rng") \ V(kRoundNearestAfz, "round-nearest-afz") \ + V(kScatter, "scatter") \ V(kSelect, "select") \ V(kSelectAndScatter, "select-and-scatter") \ V(kSend, "send") \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index e8eaf54949d6e41ebffabe7963cf737ce5ad4567..2a8c6ecd9248b9bf77153781d9c169306c9a9197 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -125,6 +125,7 @@ class HloParser { kFloat, kString, kBracedInt64List, + kBracedInt64ListList, kHloComputation, kFftType, kWindow, @@ -205,6 +206,10 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + // 'parse_and_add_item' is an lambda to parse an element in the list and add + // the parsed element to the result. It's supposed to capture the result. + bool ParseList(const TokKind start, const TokKind end, const TokKind delim, + const std::function& parse_and_add_item); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -619,6 +624,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } break; } + case HloOpcode::kAllToAll: { + optional>> tmp_groups; + optional barrier; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + c_transform(*tmp_groups, std::back_inserter(replica_groups), + [](const std::vector& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( + shape, operands, replica_groups, barrier ? *barrier : "")); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -865,18 +892,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReduce: { + auto loc = lexer_.GetLoc(); + optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() % 2) { + return Error(loc, StrCat("expects an even number of operands, but has ", + operands.size(), " operands")); + } instruction = builder->AddInstruction(HloInstruction::CreateReduce( - shape, /*operand=*/operands[0], /*init_value=*/operands[1], + shape, /*operands=*/ + tensorflow::gtl::ArraySlice(operands, 0, + operands.size() / 2), + /*init_values=*/ + tensorflow::gtl::ArraySlice( + operands, operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } @@ -1132,13 +1169,24 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional custom_call_target; + optional window; + optional dnums; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/false, + AttrTy::kConvolutionDimensionNumbers, &dnums}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( shape, operands, *custom_call_target)); + if (window.has_value()) { + instruction->set_window(*window); + } + if (dnums.has_value()) { + instruction->set_convolution_dimension_numbers(*dnums); + } break; } case HloOpcode::kHostCompute: { @@ -1231,6 +1279,42 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kScatter: { + optional> update_window_dims; + attrs["update_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; + optional> inserted_window_dims; + attrs["inserted_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; + optional> scatter_dims_to_operand_dims; + attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &scatter_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + + optional update_computation; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &update_computation}; + + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + + ScatterDimensionNumbers dim_numbers = + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/*update_window_dims, + /*inserted_window_dims=*/*inserted_window_dims, + /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateScatter( + shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], + /*updates=*/operands[2], *update_computation, dim_numbers)); + break; + } case HloOpcode::kDomain: { DomainData domain; attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; @@ -1326,7 +1410,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, bool replicated = false; std::vector devices; std::vector tile_assignment_dimensions; - Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1377,7 +1460,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, break; } case TokKind::kShape: - tile_shape = lexer_.GetShapeVal(); + // TODO(b/112302613): Left here for backward compatibility to ignore the + // removed tile shape data. lexer_.Lex(); break; case TokKind::kRbrace: @@ -1392,19 +1476,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error(loc, "replicated shardings should not have any devices assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, - "replicated shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return Error(loc, "maximal shardings should have exactly one device assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "maximal shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { @@ -1412,9 +1489,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error( loc, "non-maximal shardings must have more than one device assigned"); } - if (ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "non-maximal shardings should have a tile shape set"); - } if (tile_assignment_dimensions.empty()) { return Error( loc, @@ -1422,7 +1496,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding->mutable_tile_shape() = tile_shape; for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } @@ -1579,6 +1652,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, "value ", value, " is out of range for literal's primitive type ", PrimitiveType_Name(literal->shape().element_type()))); } + } else if (std::is_unsigned::value) { + CHECK((std::is_same::value || + std::is_same::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits::max(); + } else { + upper_bound = + static_cast(std::numeric_limits::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } } else if (value > static_cast( std::numeric_limits::max()) || value < static_cast( @@ -2180,6 +2271,26 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kBracedInt64ListList: { + std::vector> result; + auto parse_and_add_item = [&]() { + std::vector item; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, + TokKind::kComma, &item)) { + return false; + } + result.push_back(item); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + static_cast>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kSliceRanges: { SliceRanges result; if (!ParseSliceRanges(&result)) { @@ -2522,6 +2633,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +bool HloParser::ParseList(const TokKind start, const TokKind end, + const TokKind delim, + const std::function& parse_and_add_item) { + if (!ParseToken(start, StrCat("expects a list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects a list to end with ", TokKindToString(end))); +} + // param_list_to_shape ::= param_list '->' shape bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 1f0572c576c5b22cb7827ff26197e816132ce62e..4cd21841f4c25071d222cd291ed56aad2d266ca7 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -758,6 +758,46 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5 ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} } +)" +}, +{ +"scatter", +R"(HloModule StringifyScatter + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3 +} + +)" +}, +{ + "ConstantUnsignedNoUnderflow", + R"(HloModule ConstantUnsignedNoUnderflow_module + +ENTRY %ConstantUnsignedNoUnderflow () -> u64[] { + ROOT %constant = u64[] constant(1) +} + +)" +}, + +{ + "ConstantUnsignedNoOverflow", + R"(HloModule ConstantUnsignedNoOverflow_module + +ENTRY %ConstantUnsignedNoOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775807) +} + )" }, }); @@ -803,6 +843,32 @@ ENTRY ReduceR3ToR2.v3 { ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 } +)" +}, +// tuple reduce +{ +"TupleReduce", +R"(HloModule TupleReduce + +max_argmax { + value = f32[] parameter(2) + prev_max = f32[] parameter(0) + is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + max = f32[] select(is_next_larger, value, prev_max) + index = s32[] parameter(3) + prev_argmax = s32[] parameter(1) + argmax = s32[] select(is_next_larger, index, prev_argmax) + ROOT pair = (f32[], s32[]) tuple(max, argmax) +} + +ENTRY reduce_entry { + values = f32[1024]{0} parameter(0) + indices = f32[1024]{0} parameter(1) + init_value = f32[] constant(-inf) + init_index = s32[] constant(-1) + ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax +} + )" }, // infeed/outfeed @@ -1004,6 +1070,30 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add } +)" +}, +// all-to-all +{ +"AllToAll", +R"(HloModule AllToAll + +ENTRY AllToAll { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} +} + +)" +}, +// all-to-all with subgroups +{ +"AllToAllWithSubgroups", +R"(HloModule AllToAllWithSubgroups + +ENTRY AllToAllWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" +} + )" }, // Iota @@ -1015,6 +1105,17 @@ ENTRY Iota { ROOT iota = f32[100]{0} iota() } +)" +}, +// custom-call with window and dim_labels +{ +"CustomCallWithWindowAndDimLabels", +R"(HloModule CustomCallWithWindowAndDimLabels + +ENTRY Computation { + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" +} + )" } }); @@ -1213,6 +1314,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantUnsignedUnderflow) { + const string original = R"( + HloModule ConstantUnsignedUnderflow_module + ENTRY %ConstantUnsignedUnderflow () -> u64[] { + ROOT %constant = u64[] constant(-1) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U64"); +} + +TEST_F(HloParserTest, ConstantUnsignedOverflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u32[] { + ROOT %constant = u32[] constant(4294967296) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U32"); +} + +TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775808) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index b3d0a07add39968c6310392ea01daeab8a7dd9af..28194deb0e32252b372a328b006dabaf250fa2c7 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ +#include + #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,9 +36,19 @@ class HloPassFix : public Pass { StatusOr Run(HloModule* module) override { bool changed = false; bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = + std::max(static_cast(1000), module->instruction_count()); while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } } return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index cf9ceed5b2fb49eb91fea96d89c8e1efc2a3dad1..9ec983c2bc353955cb23d441d200ac8aa36951b1 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -282,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, ScheduleComputationsInModule(*module, - [&TUPLE_SIZE](const BufferValue& buffer) { + [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf( buffer.shape(), TUPLE_SIZE); }, diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 393944c20faa0b09ebc8544543b62566c836739f..879fb3bbab2ada0f924282f16b3d9ccb4c2cb203 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -31,12 +31,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { CHECK_EQ(1, ShapeUtil::Rank(input_shape)); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); - Shape tile_shape = input_shape; - auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; - tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); Array assignment(dimensions); std::iota(assignment.begin(), assignment.end(), 0); - return HloSharding(tile_shape, assignment); + return HloSharding(assignment); } HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { @@ -104,8 +101,7 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", - Join(tile_assignment_.dimensions(), ","), "]", + return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]", Join(tile_assignment_, ","), "}"); } } @@ -127,15 +123,15 @@ std::map HloSharding::UsedDevices(int64* count) const { if (IsTuple()) { for (auto& tuple_element_sharding : tuple_elements()) { auto unique_device = tuple_element_sharding.UniqueDevice(); - if (unique_device.ok()) { - device_map[unique_device.ValueOrDie()] += 1; + if (unique_device) { + device_map[*unique_device] += 1; } } element_count = tuple_elements().size(); } else { auto unique_device = UniqueDevice(); - if (unique_device.ok()) { - device_map[unique_device.ValueOrDie()] += 1; + if (unique_device) { + device_map[*unique_device] += 1; } } if (count != nullptr) { @@ -145,7 +141,6 @@ std::map HloSharding::UsedDevices(int64* count) const { } std::vector HloSharding::TileIndexForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; @@ -165,32 +160,43 @@ int64 HloSharding::DeviceForTileIndex( if (maximal_) { return *tile_assignment_.begin(); } - CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size()); return tile_assignment_(index); } -std::vector HloSharding::TileOffsetForDevice(int64 device) const { +std::vector HloSharding::TileOffsetForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - std::vector index = TileIndexForDevice(device); if (maximal_) { - // Index will always be all zeroes if we're maximal, and tile_shape_ is not - // valid. - return index; + return std::vector(shape.dimensions_size(), 0); } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] *= tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim); } return index; } -std::vector HloSharding::TileLimitForDevice(int64 device) const { +std::vector HloSharding::TileLimitForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. + if (maximal_) { + return std::vector(shape.dimensions().begin(), + shape.dimensions().end()); + } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] = (index[i] + 1) * tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), + shape_dim); } return index; } @@ -238,40 +244,31 @@ StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { return Tuple(ShapeTree(shape, *this)); } -StatusOr HloSharding::UniqueDevice() const { +tensorflow::gtl::optional HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on empty tuple"); + return tensorflow::gtl::nullopt; } - std::vector> results; - std::transform(tuple_elements_.begin(), tuple_elements_.end(), - std::back_inserter(results), - [](const HloSharding& s) { return s.UniqueDevice(); }); - if (std::all_of(results.begin(), results.end(), - [&](const StatusOr& s) { - return s.ok() && results[0].ok() && - s.ValueOrDie() == results[0].ValueOrDie(); - })) { - return results[0]; - } else { - return tensorflow::errors::InvalidArgument( - "Tuple did not contain a unique device"); + tensorflow::gtl::optional unique_device; + for (auto& tuple_sharding : tuple_elements_) { + auto device = tuple_sharding.UniqueDevice(); + if (!device || (unique_device && *device != *unique_device)) { + return tensorflow::gtl::nullopt; + } + unique_device = device; } + return unique_device; } - if (!replicated_ && maximal_ && !IsTuple()) { + if (!replicated_ && maximal_) { return static_cast(*tile_assignment_.begin()); } - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on sharding that executes on multiple devices"); + return tensorflow::gtl::nullopt; } -bool HloSharding::HasUniqueDevice() const { - if (IsTuple()) { - return UniqueDevice().status().ok(); - } else { - return !IsReplicated() && IsTileMaximal(); - } +int64 HloSharding::GetUniqueDevice() const { + auto device = UniqueDevice(); + CHECK(device) << "Sharding does not have a unique device: " << *this; + return *device; } Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { @@ -345,11 +342,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Status::OK(); } - // The tile rank must be the same as the input rank. - if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { + // The tile assignment tensor must have the same rank as the input. + if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank. sharding=", ToString(), - ", input_shape=", ShapeUtil::HumanString(shape)); + "Number of tile assignment dimensions is different to the input rank. " + "sharding=", + ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } // The correct constructor have to be used to create tile maximal shardings. @@ -359,20 +357,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, "sharding was intended, use HloSharding::Replicated(). If a device " "placement was intended, use HloSharding::AssignDevice()"); } - - // The tile assignment tensor must contain enough element to cover the full - // shape with tiles of the specified size. - for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); - if (shape.dimensions(i) > total_tile_size) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has too few element to cover the full " - "shape. Dimension ", - i, ", shape ", shape.dimensions(i), ", total size ", - total_tile_size)); - } - } - return Status::OK(); } @@ -402,7 +386,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, proto.tile_assignment_dimensions().end())); std::copy(proto.tile_assignment_devices().begin(), proto.tile_assignment_devices().end(), tile_assignment.begin()); - return HloSharding(proto.tile_shape(), tile_assignment); + return HloSharding(tile_assignment); } OpSharding HloSharding::ToProto() const { @@ -416,7 +400,6 @@ OpSharding HloSharding::ToProto() const { return result; } - *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -433,30 +416,16 @@ OpSharding HloSharding::ToProto() const { return result; } -HloSharding HloSharding::TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform) const { - CHECK(!IsTuple()); +Shape HloSharding::TileShape(const Shape& shape) const { if (IsTileMaximal()) { - return *this; + return shape; } - CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); - Shape new_tile_shape; - new_tile_shape.set_element_type(tile_shape().element_type()); - for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { - int64 dim; - if (tile_assignment().dim(i) == 1) { - dim = new_shape.dimensions(i); - } else if (transform) { - dim = transform(i, tile_shape().dimensions(i)); - } else { - dim = tile_shape().dimensions(i); - } - new_tile_shape.add_dimensions(dim); + Shape result_shape = shape; + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + (*result_shape.mutable_dimensions())[i] = + CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); } - TF_CHECK_OK( - LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); - return HloSharding::Tile(new_tile_shape, tile_assignment()); + return result_shape; } HloSharding HloSharding::GetSubSharding(const Shape& shape, @@ -498,9 +467,6 @@ size_t HloSharding::Hash() const { for (uint32 v : tile_assignment_) { h = tensorflow::Hash64Combine(h, std::hash{}(v)); } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } return h; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 6f672b0f28d2b85411d70f33da9a9f270aefc0d0..894783e5d1538fa4e8e91b65827121f32040af83 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -48,22 +48,10 @@ class HloSharding { // the input shape (one tile) assigned to a single device. static HloSharding AssignDevice(int64 device_id); - // Creates a new sharding which splits a shape into tiles each with shape - // `tile_shape`. Each tile is assigned to one device, which is specified by - // `tile_assignment`. Any tensor not a multiple of the tile size in any - // dimension is implicitly padded to the tile size. - // - // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like: - // 2 1 padding - // <------><-> - // +----+----+ - // | 0 | 1 | - // +----+----+ - // - // Split into two tiles, one of which is implicitly padded by one. - static HloSharding Tile(const Shape& tile_shape, - const Array& tile_assignment) { - return HloSharding(tile_shape, tile_assignment); + // Creates a new sharding which splits a shape into tiles amongst the devices + // specified by `tile_assignment`. + static HloSharding Tile(const Array& tile_assignment) { + return HloSharding(tile_assignment); } // Creates a new sharding which splits a one-dimensional input shape into @@ -146,24 +134,30 @@ class HloSharding { // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; - // Given a device ID, returns the offset within the input space of the + // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileOffsetForDevice(int64 device) const; + std::vector TileOffsetForDevice(const Shape& shape, + int64 device) const; - // Given a device ID, returns the limit within the input space of the + // Given a device ID, returns the limit within the specified shape of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileLimitForDevice(int64 device) const; + std::vector TileLimitForDevice(const Shape& shape, int64 device) const; - // Returns the single device this op operates on. - // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() - StatusOr UniqueDevice() const; + // Returns the single device this op operates on. If the sharding does not + // span a single device, the return value will be empty. + // In order for a sharding to span a single device, every leaf sharding must + // be maximal and not replicated, and the used device must match. + tensorflow::gtl::optional UniqueDevice() const; + + // Retrieves the unique device or fails with a CHECK. + int64 GetUniqueDevice() const; // Returns true if this op only uses a single device. - bool HasUniqueDevice() const; + bool HasUniqueDevice() const { return UniqueDevice().has_value(); } // Returns the ShapeTree containing the shardings for each element of this // tuple, if IsTuple, or a ShapeTree with a single element containing this @@ -192,7 +186,6 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -206,9 +199,6 @@ class HloSharding { } }; - // Gets the tile shape. - // REQUIRES: !IsTileMaximal() && !IsTuple() - const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } @@ -220,25 +210,15 @@ class HloSharding { return tuple_elements_; } - // Return a new sharding that can apply to the given new shape. - // If this sharding is tile-maximal, the returned sharding will be the same as - // this sharding. If this sharding is not tile-maximal, the returned - // sharding's tile size will differ: - // - Non-sharded dimensions will be adapted to be the same as `new_shape`; - // tile_dimension(i) = new_shape.dimensions(i); - // - Sharded dimensions will be kept the same unless `transform` is supplied - // in which case tile_dimension(i) = transform(i, tile_dimension(i)); - // REQUIRES: !IsTuple(). - HloSharding TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform = nullptr) const; + // Gets the tile shape. + // REQUIRES: !IsTuple() + Shape TileShape(const Shape& shape) const; private: HloSharding() : replicated_(true), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({0}) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning @@ -250,15 +230,13 @@ class HloSharding { : replicated_(false), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({1}, device_id) {} - HloSharding(const Shape& tile_shape, const Array& tile_assignment) + explicit HloSharding(const Array& tile_assignment) : replicated_(false), maximal_(false), tuple_(false), - tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} - HloSharding(const std::vector& tuple_shardings) + explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), @@ -281,7 +259,6 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; - Shape tile_shape_; Array tile_assignment_; // Only non-empty when tuple_ is true, but because empty tuples are allowed // may also be empty even then. This is a flattened list of all the leaf diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 7baa927d0e2b1abbbb2333633d16dd605ae8c8ef..45fc300fcaf5a301fe11768da77a7c0907919c39 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -39,7 +39,6 @@ Array MakeArray(tensorflow::gtl::ArraySlice dimensions, class HloShardingTest : public HloTestBase {}; TEST_F(HloShardingTest, Replicate) { - Shape tile_shape = ShapeUtil::MakeShape(U32, {4}); HloSharding sharding = HloSharding::Replicate(); EXPECT_TRUE(sharding.IsReplicated()); EXPECT_TRUE(sharding.IsTileMaximal()); @@ -51,7 +50,7 @@ TEST_F(HloShardingTest, Replicate) { EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/2)); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } TEST_F(HloShardingTest, DevicePlacement) { @@ -60,7 +59,7 @@ TEST_F(HloShardingTest, DevicePlacement) { EXPECT_TRUE(sharding.IsTileMaximal()); EXPECT_FALSE(sharding.UsesDevice(0)); EXPECT_TRUE(sharding.UsesDevice(5)); - EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie()); + EXPECT_EQ(5, sharding.GetUniqueDevice()); HloSharding other = HloSharding::Replicate(); EXPECT_NE(other, sharding); @@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) { TEST_F(HloShardingTest, Tile) { { // Test should fail because of a duplicate tile assignment. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}), /*num_devices=*/4)); } { // Test should fail because of more devices used then `num_device`. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); } - { - // Test should fail because the total tiled size in dimension 0 is 4 but we - // have 6 elements along that dimensions. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}), - /*num_devices=*/4)); - } - { // Test should pass. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + Shape shape = ShapeUtil::MakeShape(U32, {4, 5}); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}), /*num_devices=*/5)); @@ -118,12 +102,16 @@ TEST_F(HloShardingTest, Tile) { EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0})); EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1})); - EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector{0, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector{0, 3})); - EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector{2, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector{2, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0), + (std::vector{0, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3), + (std::vector{0, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2), + (std::vector{2, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1), + (std::vector{2, 3})); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } } @@ -135,8 +123,7 @@ TEST_F(HloShardingTest, NestedTuple) { ShapeUtil::MakeShape(F32, {4, 6}), }); - HloSharding tiled_sharding = HloSharding::Tile( - ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); + HloSharding tiled_sharding = HloSharding::Tile(Array({{0, 1}})); OpSharding proto; proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); @@ -187,32 +174,11 @@ TEST_F(HloShardingTest, Hash) { } { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); - EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); - } - - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); } - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 1, 2})); - EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); - } - HloSharding default_sharding = HloSharding::Replicate(); { ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), @@ -259,19 +225,6 @@ TEST_F(HloShardingTest, Hash) { } } -TEST_F(HloShardingTest, TransformShardedTileShapeTest) { - HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); - HloSharding result = sharding.TransformShardedTileShape( - ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), - [](int dim, int value) { return dim * 111; }); - HloSharding expected = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), - Array4D({{{{0, 1}, {2, 3}}}})); - EXPECT_EQ(result, expected); -} - TEST_F(HloShardingTest, ToStringReplicatedTest) { HloSharding sharding = HloSharding::Replicate(); EXPECT_EQ(sharding.ToString(), "{replicated}"); @@ -284,9 +237,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) { TEST_F(HloShardingTest, ToStringTiledTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), - Array3D({{{2, 3}}, {{5, 7}}})); - EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); + HloSharding::Tile(Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}"); } TEST_F(HloShardingTest, ToStringTupleTest) { @@ -294,21 +246,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {7, 25}), ShapeUtil::MakeShape(S32, {9, 11})}), - {HloSharding::Replicate(), - HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), - Array2D({{3, 5}})), + {HloSharding::Replicate(), HloSharding::Tile(Array2D({{3, 5}})), HloSharding::AssignDevice(3)}); EXPECT_EQ(sharding.ToString(), - "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); + "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}"); } TEST_F(HloShardingTest, OstreamTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding::Tile(Array4D({{{{0, 1}, {2, 3}}}})); std::ostringstream oss; oss << sharding; - EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); + EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}"); } TEST_F(HloShardingTest, ParseHloString) { @@ -319,8 +268,7 @@ TEST_F(HloShardingTest, ParseHloString) { }; check(HloSharding::Replicate()); check(HloSharding::AssignDevice(2)); - check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}}))); + check(HloSharding::Tile(Array4D({{{{0}, {1}}}}))); // Empty tuple. One sharding is required for empty tuples, as we need to be // able to assign sharding to them, even though they have no leaves. check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), @@ -332,8 +280,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})}); check(HloSharding::Tuple( - tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + tuple_shape, {HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)})); } { @@ -343,8 +290,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})})}); std::vector leaf_shardings = { - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)}; ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); // Assign leaf_shardings to sharding_tree leaves. diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 48f676db85ab5e7711d9e9ac900306a9ea85ef10..b78bfa0cdf4db605576fa11e18ce6c654c6a0b6d 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( } }; string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes() && - instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - node_name = StrCat( - "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); + if (debug_options_.xla_hlo_tfgraph_device_scopes()) { + auto device = instruction->sharding_unique_device(); + if (device) { + node_name = StrCat("dev", *device); + } } // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. @@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { NodeDef* node_def = graph_def_.add_node(); node_def->set_name(GetNodeNameForInstruction(instruction)); node_def->set_op(GetOpDefName(instruction)); - if (instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice()); - node_def->set_device(GetDeviceName(device)); + + auto device = instruction->sharding_unique_device(); + if (device) { + node_def->set_device(GetDeviceName(*device)); } SetNodeAttrs(instruction, node_def); if (instruction->opcode() == HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 4e3c9df3a036890ce25f5b14603d275263e8659b..7fd99fc93050b386c5ad24e6dcd2fea1bf652c3f 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out, string InstructionValueSet::ToString() const { string out = StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloValueSet& value_set) { + ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) { StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); }); return out; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index c80c1e0e7db5da1f23a99def52a54785bab97df6..3fae61f704ae73bb66b151c12cf7d900ffe42f49 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -105,6 +105,15 @@ Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } +Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(hlo, + ShapeInference::InferAllToAllTupleShape(operand_shapes)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -119,7 +128,7 @@ Status CheckIsTokenOperand(const HloInstruction* instruction, const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( - "Expected operand %lld to be token-shaped, actual shape is" + "Expected operand %lld to be token-shaped, actual shape is " "%s:\n%s", operand_no, ShapeUtil::HumanString(token->shape()).c_str(), instruction->ToString().c_str()); @@ -224,10 +233,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsArray(reduce->shape())) { + return InvalidArgument("Variadic reduce is not supported."); + } return CheckShape( reduce, ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), + {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); } @@ -510,6 +522,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + return CheckShape( + scatter, ShapeInference::InferScatterShape( + scatter->operand(0)->shape(), scatter->operand(1)->shape(), + scatter->operand(2)->shape(), + scatter->to_apply()->ComputeProgramShape(), + scatter->scatter_dimension_numbers())); +} + Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { std::vector operand_shapes; for (const HloInstruction* operand : token->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 79f7aa9f4ce66cc9b53d016f2e126033492c81e9..5a56a44f355a7b8c43d22433404d80f672024a55 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -45,6 +45,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllToAll(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -83,6 +84,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index d7458c338e9f1df9fac90270845aae0b8f779ee2..bb5b40a8a87c5eab5a5b1599581a81bbd064511b 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -36,7 +36,8 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto print_op = [&](const OpInfo& op) { + int64 cumulative_cycles = 0; + auto print_op = [&](const OpInfo& op, bool is_total = false) { // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that // were expected to be free and are actually free -- things like (on most // backends) kParameter or kConstant HLOs. There's no need to clutter the @@ -59,27 +60,44 @@ string HumanReadableProfileBuilder::ToString() const { } } + double cumulative_cycles_percent = 0; double cycles_percent = 0; + if (!is_total) { + cumulative_cycles += op.cycles; + } if (total_cycles_ > 0) { cycles_percent = op.cycles / static_cast(total_cycles_) * 100; + cumulative_cycles_percent = + cumulative_cycles / static_cast(total_cycles_) * 100; + } + + string cycles_percent_str; + if (is_total) { + // Leaving off the two trailing decimal points of "100.%" lets us save two + // columns in the output. + cycles_percent_str = "100.% 100Σ"; + } else { + cycles_percent_str = + Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " - ":: %18s :: %14s :: %16s :: %s\n", - op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds < 0 - ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps( - op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + Appendf( + &s, + "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%16s :: %s\n", + op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), + op.flop_count <= 0 + ? "" + : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + op.transcendental_count <= 0 + ? "" + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + .c_str(), + bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); }; float optimal_seconds_sum = 0.0; @@ -98,7 +116,8 @@ string HumanReadableProfileBuilder::ToString() const { VLOG(1) << "Total floating point ops: " << total_flops; print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + total_transcendentals, total_bytes, optimal_seconds_sum}, + /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8b2df3256776a7d77517daff1fe282b0dbde7045..3531b7223fb11df212fa8d30e3adba6aac6c5679 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -447,7 +447,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, int64 indexed_source_subarray_size = std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, - operand_shape.end(), 1, std::multiplies()); + operand_shape.end(), 1LL, std::multiplies()); return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); } @@ -764,7 +764,7 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); - CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l, + CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL, std::multiplies()), ShapeUtil::ElementsIn(scalar_indexed_source_shape)); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index af07370135ca2b2e53fcbcb53696e0aa12bf7a6f..f33942d67907d8f40811bde5041350a2e1e1f1fc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -120,6 +120,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -141,6 +142,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 9705687b004976fc5d35ddeb1c2a69c65ed50358..b5a9d6e8e7d66ae0c560226a79578d85eaf55644 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -874,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, // HostCompute module. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. - if (!sharding.HasUniqueDevice() || - HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) { + auto device = sharding.UniqueDevice(); + if (!device || HloSharding::IsReservedDevice(*device)) { copy->set_sharding(sharding); } } @@ -1228,7 +1228,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( const PointsToSet& points_to_set = constraints->points_to_analysis().GetPointsToSet(instruction); return points_to_set.ForEachElementWithStatus( - [this, &shape_layout, constraints]( + [&shape_layout, constraints]( const ShapeIndex& index, const PointsToSet::BufferList& buffers) -> Status { if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 0573304912b021839fd9fa9d7ae6f6bd268899b7..cdd3daf73b8ac1a4d1ec3c81224c2c0bfe8e5811 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -223,3 +223,22 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "buffer_assignment_util", + srcs = ["buffer_assignment_util.cc"], + hdrs = ["buffer_assignment_util.h"], + deps = [ + "//tensorflow/compiler/xla/service:buffer_assignment", + ], +) + +cc_library( + name = "math_ops", + srcs = ["math_ops.cc"], + hdrs = ["math_ops.h"], + deps = [ + ":llvm_util", + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index 2552ff4a6a06d18f34b4ba224b66d6d97ddd74d3..fe5ec1cc66d06e85ce70625ef7cf764a37b29166 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -56,12 +56,12 @@ ENTRY while3 { )"; CompileAndVerifyIr(hlo_string, R"( -; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval +; CHECK-LABEL: @body(i8* %retval ; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] -; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:.*]] +; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; -; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params -; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0 +; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* ; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..4eb5d9fb4750927ca189e02f312b2d6be7fdd418 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" + +namespace xla { +namespace llvm_ir { +static const HloInstruction& InstrForConstantBufferAllocation( + const BufferAllocation& allocation) { + CHECK(allocation.is_constant()); + HloInstruction* const_instr = nullptr; + for (const auto& buffer_offset_pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = buffer_offset_pair.first; + // BufferAssignment may have assigned non-constant instructions to this + // allocation too so we can't CHECK this condition. E.g. for + // + // while(init = constant, body = identity, cond = ...) + // + // the LogicalBuffer for the kWhile instruction will have the same + // BufferAllocation as the LogicalBuffer for the (init) constant. + if (buffer->instruction()->opcode() == HloOpcode::kConstant) { + CHECK_EQ(const_instr, nullptr) + << const_instr->ToString() << " " << buffer->ToString(); + const_instr = buffer->instruction(); + } + } + CHECK_NE(const_instr, nullptr); + return *const_instr; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + string instr_name = InstrForConstantBufferAllocation(allocation).name(); + for (char& c : instr_name) { + if (c == '.') { + c = '_'; + } + } + return tensorflow::strings::StrCat("buffer_for_", instr_name); +} + +const Literal& LiteralForConstantAllocation( + const BufferAllocation& allocation) { + return InstrForConstantBufferAllocation(allocation).literal(); +} +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bfb6eecb87f6a1b756b3a8da3377f608dd7f0be7 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +namespace xla { +namespace llvm_ir { +// In XLA:GPU we map constant buffer allocations to globals in the generated +// LLVM IR. This function gives us the name of the global variable a constant +// buffer is mapped to. Not used on XLA:CPU. +string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation); + +// Returns the Literal corresponding to `allocation`, which must be a constant +// allocation. +const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 1bd73fc793a3005ce69b53db7ef2c19cd06c633e..27fbb11e2ede66a1268e7e949634b2c7d29cbc1c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -56,10 +56,6 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to officially document different behavior. llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size); llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); start_index[i] = diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7a9170f3794159e887d9eeacb248cd5dcff2c320..2b6caee6aa72f426cf85c8c56c3ef500ff8c5d3d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -327,6 +327,7 @@ llvm::Value* IrArray::Index::Linearize( llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. + CHECK_EQ(size(), dimensions.size()); llvm::Value* logical_linear_index = GetConstantWithIndexType(0); int64 multiplier = 1; for (ssize_t i = size() - 1; i >= 0; --i) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index e4f65bd427a5aab09292c6a4cb2586d688cd115d..e6126881af8b8123e08a4eaa934b52a7fd378ce6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -657,5 +657,56 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { } } +std::pair UMulLowHigh32(llvm::IRBuilder<>* b, + llvm::Value* src0, + llvm::Value* src1) { + CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32); + CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32); + llvm::Type* int64_ty = b->getInt64Ty(); + src0 = b->CreateZExt(src0, int64_ty); + src1 = b->CreateZExt(src1, int64_ty); + return SplitInt64ToInt32s(b, b->CreateMul(src0, src1)); +} + +std::pair SplitInt64ToInt32s( + llvm::IRBuilder<>* b, llvm::Value* value_64bits) { + CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64); + llvm::Type* int32_ty = b->getInt32Ty(); + llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty); + llvm::Value* high_32bits = + b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty); + return std::make_pair(low_32bits, high_32bits); +} + +llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState( + llvm::Module* module, llvm::IRBuilder<>* b) { + static const char* kPhiloxRngStateVariableName = "philox_rng_state"; + llvm::GlobalVariable* state_ptr = + module->getNamedGlobal(kPhiloxRngStateVariableName); + if (!state_ptr) { + state_ptr = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/b->getInt64Ty(), + /*isConstant=*/false, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/b->getInt64(0), + /*Name=*/kPhiloxRngStateVariableName); + } + return state_ptr; +} + +void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module, + llvm::IRBuilder<>* builder) { + llvm::GlobalVariable* state_ptr = + GetOrCreateVariableForPhiloxRngState(module, builder); + llvm::Value* state_value_old = builder->CreateLoad(state_ptr, "load_state"); + // If the 64-bit value overflows, we use the wraparound value. This should + // be fine in practice as we only add one to the value each time when a RNG is + // executed. + llvm::Value* state_value_new = builder->CreateAdd( + state_value_old, builder->getInt64(value), "inc_state"); + builder->CreateStore(state_value_new, state_ptr); +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index d8746ffe01189aa43248d5b0ca4966afadcf70a6..09583985342033d486d50910b6f5ca732a9a3756 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -292,6 +292,27 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, // don't start with xla_ to LLVM. void InitializeLLVMCommandLineOptions(const HloModuleConfig& config); +// Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the +// result as a pair of (low 32 bits, high 32 bits). +std::pair UMulLowHigh32(llvm::IRBuilder<>* b, + llvm::Value* src0, + llvm::Value* src1); +// Splits the 64-bit integer value into its high and low 32 bits. +std::pair SplitInt64ToInt32s( + llvm::IRBuilder<>* b, llvm::Value* value_64bits); + +// Checks whether a global variable is already created to represent a +// state passed between RNG calls implemented with Philox algorithm. If not, +// creates such a variable. Returns the global variable. +llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState( + llvm::Module* module, llvm::IRBuilder<>* b); + +// Adds a value to the global state variable each time when a RNG hlo is +// executed. The value of this global state variable is added to the seed +// of the Philox RNG algorithm so that calling the same RNG Hlo multiple times +// should rarely produce the same result. +void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module, + llvm::IRBuilder<>* b); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e115cdabf4b290617700276dba8f2e5648a7c07 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace llvm_ir { + +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) { + llvm::Type* type = input->getType(); + + // Clamp the input to [-9, 9]. + llvm::Value* input_clamped = llvm_ir::EmitFloatMin( + llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b), + llvm::ConstantFP::get(type, 9.0), b); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + llvm::Value* input_squared = b->CreateFMul(input_clamped, input_clamped); + llvm::Value* numerator = llvm::ConstantFP::get(type, numerator_coeffs[0]); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = b->CreateFAdd(b->CreateFMul(input_squared, numerator), + llvm::ConstantFP::get(type, numerator_coeffs[i])); + } + + numerator = b->CreateFMul(input_clamped, numerator); + + llvm::Value* denominator = llvm::ConstantFP::get(type, denominator_coeffs[0]); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = + b->CreateFAdd(b->CreateFMul(input_squared, denominator), + llvm::ConstantFP::get(type, denominator_coeffs[i])); + } + + return b->CreateFDiv(numerator, denominator); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/pool_test.cc b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h similarity index 54% rename from tensorflow/compiler/xla/service/pool_test.cc rename to tensorflow/compiler/xla/service/llvm_ir/math_ops.h index 8c4fe258e38fff1b2086d8809bfc487e11ef713f..6c8bc3a076367eae2f1829966be2872e5f258178 100644 --- a/tensorflow/compiler/xla/service/pool_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,28 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/pool.h" +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ -#include "tensorflow/compiler/xla/test_helpers.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" namespace xla { -namespace { +namespace llvm_ir { -using PoolTest = ::testing::Test; +// Emits an approximation of tanh. The implementation uses the same rational +// interpolant as implemented in Eigen3. +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input); -TEST_F(PoolTest, Test) { - Pool pool; - - { - auto ptr = pool.Allocate(); - EXPECT_NE(nullptr, ptr.get()); - *ptr = 5; - } - - auto ptr = pool.Allocate(); - EXPECT_NE(nullptr, ptr.get()); - EXPECT_EQ(5, *ptr); -} - -} // namespace +} // namespace llvm_ir } // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 6f261c32f4181a6c4107f7fbcf782feb4347e587..e546f5cc4ae305b40c1bdbcae090daadee11241b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -38,19 +39,18 @@ namespace llvm_ir { namespace { // Adds the inner comparison loop where we compare elements pointed to by // 'keys_index' and 'compare_keys_index'. -void EmitCompareLoop(int64 dimension_to_sort, - const llvm_ir::IrArray::Index& keys_index, - const llvm_ir::IrArray::Index& compare_keys_index, - const llvm_ir::IrArray& keys_array, llvm::IRBuilder<>* b) { - // TODO(b/26783907): parallelize this loop. - +void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, + const IrArray::Index& compare_keys_index, + const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + llvm::IRBuilder<>* b) { // if (is_smaller_index && // compare_keys[dimension_to_sort] < dimension_to_sort_bound) llvm::Value* is_smaller_index = b->CreateICmpSLT( keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); int64 dimension_to_sort_bound = keys_array.GetShape().dimensions(dimension_to_sort); - auto if_data = llvm_ir::EmitIfThenElse( + auto if_data = EmitIfThenElse( b->CreateAnd(is_smaller_index, b->CreateICmpSLT(compare_keys_index[dimension_to_sort], keys_index.GetConstantWithIndexType( @@ -63,30 +63,36 @@ void EmitCompareLoop(int64 dimension_to_sort, auto comparison = primitive_util::IsFloatingPointType(key_type) // TODO(b/26783907): Figure out how to handle NaNs. - ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2) + ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1) : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type) ? llvm::ICmpInst::ICMP_SLT : llvm::ICmpInst::ICMP_ULT, - key1, key2); - auto min_key = b->CreateSelect(comparison, key1, key2); - auto max_key = b->CreateSelect(comparison, key2, key1); - keys_array.EmitWriteArrayElement(keys_index, min_key, b); - keys_array.EmitWriteArrayElement(compare_keys_index, max_key, b); + key2, key1); + // If key2 < key1 + auto if_smaller_data = + EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); + SetToFirstInsertPoint(if_smaller_data.true_block, b); + // Swap key1 with key2. + keys_array.EmitWriteArrayElement(keys_index, key2, b); + keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); + if (values_array.has_value()) { + // Also swap the values. + auto value1 = values_array.value().EmitReadArrayElement(keys_index, b); + auto value2 = + values_array.value().EmitReadArrayElement(compare_keys_index, b); + values_array.value().EmitWriteArrayElement(keys_index, value2, b); + values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b); + } } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, tensorflow::StringPiece name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { const Shape& keys_shape = keys_array.GetShape(); - // TODO(b/26783907): This case can probably be avoided with the Algebraic - // Simplifier. - if (ShapeUtil::IsScalar(keys_shape)) { - return Status::OK(); - } - // Create loop nests which loop through the operand dimensions. The sort // dimension is handled in the innermost loop which performs the sorting. ForLoopNest loop_nest(name, b); @@ -131,7 +137,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, compare_keys_index[dimension_to_sort] = b->CreateXor(compare_index[0], xor_mask); EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, b); + keys_array, values_array, b); return Status::OK(); }; if (launch_dimensions != nullptr) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index e75f9b08fbba7c79b8354698ad17e79c154bd67e..8458744c6bc0e50a1c1cc8d3e66e29c7d4f74d73 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -30,6 +31,7 @@ namespace llvm_ir { // implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, tensorflow::StringPiece name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 53efc30c3653879709fceae3dcdd4f679740f622..5e02096ee501b23a7976a50f13bb7e7f3c5e2d34 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 39d6734c3fc06df6832cf67edddbc7c14c815cd1..8f707ea9046a00a15cac469672a7a992f20bf483 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/service/pool.h b/tensorflow/compiler/xla/service/pool.h deleted file mode 100644 index 8e710ebb6dc17e0e204ba6ab3c6c159627cd9d3b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/pool.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_POOL_H_ - -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/core/platform/mutex.h" - -namespace xla { - -// Pool of values, which are created as needed and destroyed when the `Pool` is -// destroyed -template -class Pool { - public: - struct Deleter { - void operator()(T* ptr) { pool->Deallocate(ptr); } - - Pool* pool; - }; - - // A pointer to a taken element of a `Pool` which returns it to the pool on - // destruction - using SmartPtr = std::unique_ptr; - - // Constructs a `Pool` with given factory function, which need not be - // thread-safe. - explicit Pool(std::function()> factory) - : factory_(factory) {} - - explicit Pool() : Pool([]() { return MakeUnique(); }) {} - - // Returns a pointer to a value in the pool, creating a new value if none is - // free. The returned smart pointer returns the element to the pool on - // destruction. - // - // This method is thread-safe. - SmartPtr Allocate() { - tensorflow::mutex_lock lock(mu_); - T* ptr; - if (!xs_.empty()) { - ptr = std::move(xs_.back()).release(); - xs_.pop_back(); - } else { - ptr = factory_().release(); - } - Deleter del = {this}; - return std::unique_ptr(ptr, del); - } - - private: - // Puts a pointer to a value back into the pool, leaving it free for future - // use. - // - // This method is thread-safe. - void Deallocate(T* ptr) { - tensorflow::mutex_lock lock(mu_); - xs_.push_back(std::unique_ptr(ptr)); - } - - const std::function()> factory_ GUARDED_BY(mu_); - std::vector> xs_ GUARDED_BY(mu_); - tensorflow::mutex mu_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_POOL_H_ diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 636013cbb561f8506e173bd634e07b48a8dc570e..433560e32258a01c10ade85ec2faf52884271497 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -55,7 +56,6 @@ limitations under the License. using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; -using ::xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -376,7 +376,7 @@ Service::ExecuteParallelAndRegisterResult( ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. - std::vector::SmartPtr> streams; + std::vector streams; std::vector> timers; // Global data handles for the computation results, one for each computation. @@ -403,7 +403,7 @@ Service::ExecuteParallelAndRegisterResult( CHECK_EQ(replicas.size(), arguments[i].size()); std::vector result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); @@ -515,13 +515,13 @@ StatusOr Service::ExecuteAndRegisterResult( arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. - std::vector::SmartPtr> streams; + std::vector streams; TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, SingleComputationDeviceHandle())); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } @@ -533,7 +533,7 @@ StatusOr Service::ExecuteAndRegisterResult( // Set up run options. std::vector run_options; - for (const Pool::SmartPtr& stream : streams) { + for (const StreamPool::Ptr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); options.set_device_ordinal(stream->parent()->device_ordinal()); @@ -1052,11 +1052,12 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - Literal literal; + auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), &literal)); - *result->mutable_literal() = literal.ToProto(); + executor, arg->shape_with_layout(), *literal)); + *result->mutable_literal() = literal->ToProto(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 7f3910cdb0366078b97fb5f6a2dc498b37570926..dbfed628bfcabffe66bef41a82e0e2430897d80d 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_ #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -27,8 +27,7 @@ namespace xla { // data, now only a stream cache for GPU backend. class ServiceExecutableRunOptions { public: - using StreamBorrower = - std::function::SmartPtr>(int)>; + using StreamBorrower = std::function(int)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} @@ -51,7 +50,7 @@ class ServiceExecutableRunOptions { // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr::SmartPtr> BorrowStream(int device_ordinal) const { + StatusOr BorrowStream(int device_ordinal) const { return borrow_stream_ ? borrow_stream_(device_ordinal) : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 35df792b07022b2338fcecc25eb8a0718626e464..a4ea2b28f4dbf41d61702f1af2d65c4d2c86d578 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -58,66 +58,101 @@ Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { return Status::OK(); } -Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { - if (reducer_shape.parameters_size() != 2) { - return InvalidArgument( - "Reduction function must take 2 parameters, but " +Status VerifyReducerShape( + const ProgramShape& reducer_shape, + tensorflow::gtl::ArraySlice init_value_shapes, + tensorflow::gtl::ArraySlice input_element_types, + int64 inputs) { + if (reducer_shape.parameters_size() != inputs * 2) { + return InvalidArgument( + "Reduction function must take %lld parameters, but " "takes %d parameter(s).", - reducer_shape.parameters_size()); + inputs * 2, reducer_shape.parameters_size()); } const Shape& accumulator_shape = reducer_shape.result(); - if (!ShapeUtil::IsArray(accumulator_shape) || - ShapeUtil::Rank(accumulator_shape) != 0) { - return InvalidArgument( - "Reduction function must produce a scalar but has shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); - } - - // Check that the accumulator can be passed in as the first argument. - // Note: comparing here and below with Compatible since we don't care about - // layout in scalars - see b/26668201 for a longer-term vision. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) { + std::vector accumulator_subshapes; + if (ShapeUtil::IsArray(accumulator_shape)) { + if (inputs != 1) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but " + "produces a scalar", + inputs); + } + accumulator_subshapes.push_back(&accumulator_shape); + } else if (ShapeUtil::IsTuple(accumulator_shape)) { + if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but has " + "%lld elements", + inputs, ShapeUtil::TupleElementCount(accumulator_shape)); + } + for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { + accumulator_subshapes.push_back(&element_shape); + } + } else { return InvalidArgument( - "Reduction function's first parameter shape differs from the " - "result shape: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(), + "Reduction function must produce a scalar or tuple of scalars, but has " + "shape: %s", ShapeUtil::HumanString(accumulator_shape).c_str()); } - // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - init_value_shape)) { - return InvalidArgument( - "Reduction function's accumulator shape differs from the " - "init_value shape: %s vs %s", - ShapeUtil::HumanString(accumulator_shape).c_str(), - ShapeUtil::HumanString(init_value_shape).c_str()); - } - - // Check that the inputs can be passed in as the second argument. - const Shape& input_element_shape = - ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape differs from the " - "input type element type: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + for (const Shape* element_shape : accumulator_subshapes) { + if (ShapeUtil::Rank(*element_shape) != 0) { + return InvalidArgument( + "Reduction function must return a scalar or tuple of scalars but " + "returns shape: %s", + ShapeUtil::HumanString(accumulator_shape).c_str()); + } } - // Currently the accumulator and inputs must be the same type, - // though that restriction could be relaxed. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape must " - "match the result shape, but got %s vs %s.", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(accumulator_shape).c_str()); + for (int64 i = 0; i < inputs; ++i) { + // Check that the accumulator can be passed in as the first argument. + // Note: comparing here and below with Compatible since we don't care about + // layout in scalars - see b/26668201 for a longer-term vision. + if (!ShapeUtil::Compatible(*accumulator_subshapes[i], + reducer_shape.parameters(i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "result shape: %s vs %s", + i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } + // Check that init_value's shapes are suitable for reducer_shape. + if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], + *init_value_shapes[i])) { + return InvalidArgument( + "Reduction function's accumulator shape at index %lld differs from " + "the init_value shape: %s vs %s", + i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), + ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + } + // Check that the inputs can be passed in as the non-accumulator arguments. + const Shape input_element_shape = + ShapeUtil::MakeShape(input_element_types[i], {}); + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + input_element_shape, reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "input type element type: %s vs %s", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(input_element_shape).c_str()); + } + // Check that the accumulator and inputs to the reducer function match. + // If the accumulator is scalar, it must have the same type as the inputs + // (up to fp precision). If it is a tuple, then the k-th element of the + // tuple must have the same type as the K-th input (again, up to fp + // precision.) + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape must " + "match the result shape, but got %s vs %s.", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } } return Status::OK(); @@ -1744,11 +1779,83 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(operand_shape_values); } +/* static */ StatusOr ShapeInference::InferAllToAllShape( + const Shape& shape, int64 split_dimension, int64 concat_dimension, + int64 split_count) { + TF_RET_CHECK(split_count > 0); + if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + return InvalidArgument( + "AllToAll split_dimension %lld is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + return InvalidArgument( + "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (shape.dimensions(split_dimension) % split_count != 0) { + return InvalidArgument( + "AllToAll split dimension size %lld must be dividable by split_count " + "%lld.", + shape.dimensions(split_dimension), split_count); + } + std::vector new_dimensions(shape.dimensions().begin(), + shape.dimensions().end()); + new_dimensions[split_dimension] /= split_count; + new_dimensions[concat_dimension] *= split_count; + return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); +} + +/* static */ StatusOr ShapeInference::InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes) { + // An Alltoall HLO instruction receives N operands (with the same shape) and + // returns a tuple that contains N array shapes. + TF_RET_CHECK(!operand_shapes.empty()); + for (int i = 0; i < operand_shapes.size(); i++) { + if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) { + return InvalidArgument( + "HLO all-to-all has operands with different shapes: the 0th " + "operand shape %s, but the %dth operand has shape %s.", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, + ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + } + } + + return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); +} + /* static */ StatusOr ShapeInference::InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply) { - // Check that the dimension to reduce are in-bounds for the given shape. + if (arg_shapes.empty()) { + return InvalidArgument("Reduce must have at least 2 arguments, has 0"); + } + if (arg_shapes.size() % 2) { + return InvalidArgument( + "Reduce must have an even number of arguments, has %lu", + arg_shapes.size()); + } + int64 num_reduced_args = arg_shapes.size() / 2; + + tensorflow::gtl::ArraySlice reduced_args(arg_shapes, 0, + num_reduced_args); + // Check that all of the reduced tensors have the same dimensions. The element + // types may be different. + for (int64 i = 1; i < num_reduced_args; ++i) { + if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { + return InvalidArgument( + "All reduced tensors must have the sime dimension. Tensor 0 has " + "shape %s, Tensor %lld has shape %s", + ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, + ShapeUtil::HumanString(*reduced_args[i]).c_str()); + } + } + + // Check that the dimensions to reduce are in-bounds for the given shape. + // We've already verified all reduced tensors have the same dimensions, so it + // doesn't matter which one we choose. + const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( @@ -1756,8 +1863,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(arg).c_str()); } } - TF_RETURN_IF_ERROR( - VerifyReducerShape(to_apply, init_value, arg.element_type())); + + tensorflow::gtl::ArraySlice init_values( + arg_shapes, num_reduced_args, arg_shapes.size()); + std::vector element_types; + for (const Shape* arg : reduced_args) { + element_types.push_back(arg->element_type()); + } + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types, + num_reduced_args)); std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); @@ -1768,15 +1882,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions); + if (ShapeUtil::IsScalar(to_apply.result())) { + return ShapeUtil::MakeShape(to_apply.result().element_type(), + new_dimensions); + } else { + std::vector result_subshapes; + for (const Shape& subshape : to_apply.result().tuple_shapes()) { + result_subshapes.push_back( + ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + } + return ShapeUtil::MakeTupleShape(result_subshapes); + } } /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, - operand_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {operand_shape.element_type()}, + /*inputs=*/1)); return InferWindowOutputShape(operand_shape, window, init_value_shape.element_type(), /*allow_negative_padding=*/false); @@ -1821,8 +1946,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Check if the scatter function has a proper shape as a reduction. - TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape, - source_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape}, + {source_shape.element_type()}, + /*inputs=*/1)); // Check if the result shape of window operation matches the source shape. TF_ASSIGN_OR_RETURN(const Shape& window_result_shape, @@ -2568,4 +2694,194 @@ static Status ValidateGatherDimensionNumbers( return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); } +namespace { + +Status ValidateScatterDimensionNumbers( + const Shape& operand_shape, + tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + // Validate update_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.update_window_dims())) { + return InvalidArgument( + "update_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.update_window_dims()) != + dim_numbers.update_window_dims().end()) { + return InvalidArgument( + "update_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + const int64 updates_rank = ShapeUtil::Rank(updates_shape); + for (int64 window_dim : dim_numbers.update_window_dims()) { + if (window_dim < 0 || window_dim >= updates_rank) { + return InvalidArgument( + "Invalid update_window_dims set in scatter op; valid range is [0, " + "%lld). got: %lld.", + updates_rank, window_dim); + } + } + + // Validate inserted_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + return InvalidArgument( + "inserted_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + dim_numbers.inserted_window_dims().end()) { + return InvalidArgument( + "inserted_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { + if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid inserted_window_dims set in scatter op; valid range is [0, " + "%d), got: %lld.", + operand_shape.dimensions_size(), inserted_dim); + } + } + + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. + if (dim_numbers.scatter_dims_to_operand_dims_size() != + scatter_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Scatter op has %d elements in scatter_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "These two numbers must be equal.", + dim_numbers.scatter_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + scatter_indices_shape[dim_numbers.index_vector_dim()]); + } + for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) { + int64 scatter_dim_to_operand_dim = + dim_numbers.scatter_dims_to_operand_dims(i); + if (scatter_dim_to_operand_dim < 0 || + scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld.", + operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); + } + } + std::vector sorted_scatter_dims_to_operand_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + c_sort(sorted_scatter_dims_to_operand_dims); + if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + sorted_scatter_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " + "got: %s.", + Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +} // namespace + +/*static*/ StatusOr ShapeInference::InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers) { + TF_RETURN_IF_ERROR( + ExpectArray(operand_shape, "operand tensor of scatter op")); + TF_RETURN_IF_ERROR( + ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); + TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); + + if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { + return InvalidArgument( + "Scatter indices parameter must be an integral tensor; got %s.", + ShapeUtil::HumanString(scatter_indices_shape).c_str()); + } + + if (scatter_indices_shape.dimensions_size() < + scatter_dim_numbers.index_vector_dim() || + scatter_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Scatter index leaf dimension must be within [0, rank(scatter_indices)" + " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " + "is %lld.", + scatter_indices_shape.dimensions_size(), + scatter_dim_numbers.index_vector_dim()); + } + + // Check if the update computation has a proper shape as a reduction. + const Shape init_value_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), {}); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {updates_shape.element_type()}, + /*inputs=*/1)); + + std::vector expanded_scatter_indices_shape = + ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions())); + if (expanded_scatter_indices_shape.size() == + scatter_dim_numbers.index_vector_dim()) { + expanded_scatter_indices_shape.push_back(1); + } + + int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + + scatter_dim_numbers.update_window_dims_size(); + if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + expected_updates_rank, + ShapeUtil::Rank(updates_shape)); + } + + TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( + operand_shape, expanded_scatter_indices_shape, updates_shape, + scatter_dim_numbers)); + + int64 inserted_dims_seen = 0; + std::vector max_update_window_bounds; + for (int i = 0; i < operand_shape.dimensions_size(); ++i) { + if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && + scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { + ++inserted_dims_seen; + } else { + max_update_window_bounds.push_back(operand_shape.dimensions(i)); + } + } + for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { + auto update_window_dim = scatter_dim_numbers.update_window_dims(i); + if (updates_shape.dimensions(update_window_dim) > + max_update_window_bounds[i]) { + return InvalidArgument( + "Bounds of the window dimensions of updates must not exceed the " + "bounds of the corresponding dimensions of operand. For dimension " + "%lld, updates bound is %lld, operand bound is %lld.", + update_window_dim, updates_shape.dimensions(update_window_dim), + max_update_window_bounds[i]); + } + } + + int64 scatter_dims_seen = 0; + for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + bool is_update_window_dim = + c_binary_search(scatter_dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + continue; + } + if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) { + ++scatter_dims_seen; + } + if (updates_shape.dimensions(i) != + expanded_scatter_indices_shape[scatter_dims_seen]) { + return InvalidArgument( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices. For " + "scatter dimension %lld, updates bound is %lld, scatter_indices " + "bound is %lld.", + i, updates_shape.dimensions(i), + expanded_scatter_indices_shape[scatter_dims_seen]); + } + ++scatter_dims_seen; + } + + return operand_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 1a5684e3c306eef90fd1bfdf4565b0dcde2fbab6..c185b0a1bd79e23e0d76daad50fb4a9708a743dd 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -119,11 +119,22 @@ class ShapeInference { const Shape& in, FftType fft_type, tensorflow::gtl::ArraySlice fft_length); - // Infers the shape produced a cross replica sum with the given operand + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferCrossReplicaSumShape( tensorflow::gtl::ArraySlice operand_shapes); + // Infers final shape of an Alltoall operation that is created by the xla + // builder. + static StatusOr InferAllToAllShape(const Shape& shape, + int64 split_dimension, + int64 concat_dimension, + int64 split_count); + + // Infers the shape of an HLO all-to-all instruction. + static StatusOr InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // @@ -131,7 +142,7 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply); @@ -268,6 +279,14 @@ class ShapeInference { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Helper that validates the given input shape, scatter indices shape, updates + // shape, and scatter dimension numbers that constitute a scatter operation, + // and returns the result shape of the scatter operation. + static StatusOr InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 6046d50c6d41a3956b996a3320848784ffd59068..a73fa181cdd13dc7fabcdc367ae117e19bdc3e5f 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -63,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { tensorflow::gtl::ArraySlice dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( - arg, f32_, dimensions_to_reduce, to_apply); + {&arg, &f32_}, dimensions_to_reduce, to_apply); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, inferred_status.ValueOrDie())); @@ -703,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { /*dimensions_to_reduce=*/{0, 1, 2}); } +TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), + inferred_status.ValueOrDie())); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, + ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr( + "parameter shape differs from the result shape: s32[] vs f32[]")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must have at least 2 arguments, has 0")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("accumulator shape at index 0 differs from the " + "init_value shape: s32[] vs f32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = ShapeInference::InferReduceShape( - ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, - to_apply); + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("out-of-bounds dimension")); @@ -715,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), @@ -725,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - HasSubstr("first parameter shape differs")); + HasSubstr("0-th parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { @@ -1536,7 +1626,7 @@ TEST_F(ShapeInferenceTest, BadSort) { << statusor.status(); } -class GatherShapeInferenceTest : public ShapeInferenceTest { +class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); @@ -1553,9 +1643,13 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); + const ProgramShape to_apply_ = + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); }; -TEST_F(GatherShapeInferenceTest, TensorFlowGather) { +// Shape inference tests for Gather. + +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, @@ -1570,7 +1664,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, @@ -1585,7 +1679,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, @@ -1600,7 +1694,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( @@ -1617,7 +1711,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( @@ -1635,7 +1729,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( @@ -1653,7 +1747,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { +TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, @@ -1671,7 +1765,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { +TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { // The gather indices "tensor" is a scalar S here that's used to slice out // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, @@ -1689,7 +1783,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1704,7 +1798,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1719,7 +1813,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1734,7 +1828,7 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1751,7 +1845,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1768,7 +1862,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1784,7 +1878,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1800,7 +1894,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1818,7 +1912,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1835,7 +1929,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1853,7 +1947,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1872,7 +1966,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1890,7 +1984,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1908,7 +2002,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1924,7 +2018,8 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { +TEST_F(ScatterGatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1940,7 +2035,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1958,7 +2053,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1975,7 +2070,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1992,5 +2087,575 @@ TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { << statusor.status(); } +// Shape inference tests for Scatter. + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {64, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {10, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 8}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndicesV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterNdWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) { + // This is equivalent to a dynamic update slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3, 4}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) { + // The scalar indices "tensor" is a scalar S here that's used to update a + // [30,29,28,27] shaped tensor within the operand at position S. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for operand")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + ScatterWithTupleShapedScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for scatter indices")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for updates")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/10)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter index leaf dimension must be within [0, " + "rank(scatter_indices) + 1)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Updates tensor must be of rank 7; got 8.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) { + const ProgramShape invalid_update_computation = + ShapeUtil::MakeProgramShape({f32_}, f32_); + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), + invalid_update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Reduction function must take 2 parameters, but takes 1")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 8, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 9}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid update_window_dims set in scatter op; valid " + "range is [0, 9)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{2, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 5}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid inserted_window_dims set in scatter op; valid " + "range is [0, 5)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of scatter_indices " + "is 5. These two numbers must be equal")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain " + "is [0, 5), got: 4->10")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0582c6a2d3a05e2ed5aead5faac54e536d350cd --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { + std::unique_ptr stream; + { + tensorflow::mutex_lock lock(mu_); + if (!streams_.empty()) { + // Re-use an existing stream from the pool. + stream = std::move(streams_.back()); + streams_.pop_back(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } + } + + if (!stream) { + // Create a new stream. + stream = MakeUnique(executor); + stream->Init(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool created new stream"; + } + + // Return the stream wrapped in Ptr, which has our special deleter semantics. + PtrDeleter deleter = {this}; + return Ptr(stream.release(), deleter); +} + +void StreamPool::ReturnStream(se::Stream* stream) { + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool returning ok stream"; + tensorflow::mutex_lock lock(mu_); + streams_.emplace_back(stream); + } else { + // If the stream has encountered any errors, all subsequent operations on it + // will fail. So just delete the stream, and rely on new streams to be + // created in the future. + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool deleting !ok stream"; + delete stream; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.h b/tensorflow/compiler/xla/service/stream_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..7221d323a61593ac4b203a81b6046d81a5beaaf0 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Pool of stream_executor::Streams, which are created as needed and +// destroyed when the pool is destroyed. +class StreamPool { + public: + struct PtrDeleter { + void operator()(se::Stream* stream) { pool->ReturnStream(stream); } + StreamPool* pool; + }; + + // Stream pointer type returned by BorrowStream, which returns the + // stream to the pool on destruction. + using Ptr = std::unique_ptr; + + StreamPool() {} + + // Returns a pointer to a stream in the pool, creating a new stream + // if none are available in the pool. The returned smart pointer + // returns the stream to the pool on destruction. + // + // This method is thread-safe. + Ptr BorrowStream(se::StreamExecutor* executor); + + private: + // Puts a pointer to a stream back into the pool, leaving it free + // for future use. Streams that have previously encountered errors + // are deleted, and not returned to the pool. + // + // This method is thread-safe. + void ReturnStream(se::Stream* stream); + + tensorflow::mutex mu_; + std::vector> streams_ GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..aaf5c37b0d250f78cb57639255ac9b59e1b462f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace { + +class StreamPoolTest : public ::testing::Test { + protected: + std::unique_ptr NewStreamExecutor() { + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie(); + se::StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).ConsumeValueOrDie(); + } +}; + +TEST_F(StreamPoolTest, EmptyPool) { StreamPool pool; } + +TEST_F(StreamPoolTest, OneStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow and return a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + stream1 = nullptr; + + // Borrow and return another stream. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + stream2 = nullptr; + + // The underlying streams should be the same, since stream1 was the + // only stream available in the pool when stream2 was borrowed. + EXPECT_EQ(stream1_ptr, stream2_ptr); +} + +TEST_F(StreamPoolTest, TwoStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow two streams. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different, since we haven't + // returned either of them yet. + EXPECT_NE(stream1_ptr, stream2_ptr); + + // Return stream1 and borrow stream3. + stream1 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + + // stream1 and stream3 should be the same. + EXPECT_EQ(stream1_ptr, stream3_ptr); + EXPECT_NE(stream2_ptr, stream3_ptr); + + // Return stream2, and borrow stream4. + stream2 = nullptr; + StreamPool::Ptr stream4 = pool.BorrowStream(executor.get()); + se::Stream* stream4_ptr = stream4.get(); + EXPECT_TRUE(stream4->ok()); + + // Stream2 and stream4 should be the same. + EXPECT_EQ(stream2_ptr, stream4_ptr); + EXPECT_NE(stream3_ptr, stream4_ptr); +} + +TEST_F(StreamPoolTest, BadStreamDiscarded) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Force an error on the stream; here we call a method that requires + // DNN support, which we know the Host platform doesn't support. + stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1->ok()); + + // Return stream1 and borrow stream2. + stream1 = nullptr; + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. + // + // The check that stream2->ok() serves as a good-enough check. + + // Return stream2 and borrow stream3. The previous error on stream1 + // has no effect on these streams, and they are the same. + stream2 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + EXPECT_EQ(stream2_ptr, stream3_ptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 7232c658b3f0687ac93a83e46a200f88bf202084..32d368a90429ec026120bdf033957617eeaba23e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -43,15 +43,39 @@ TransferManager::GetPlatformTransferManagers() { StatusOr> TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferLiteralFromDevice(substream, device_buffer, - [&](StatusOr> arg) { - ret = std::move(arg); + Status s; + Literal literal(device_buffer.on_host_shape()); + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }); + n.WaitForNotification(); + if (!s.ok()) { + return s; + } + return MakeUnique(std::move(literal)); +} + +Status TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + Status ret; + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + ret = status; n.Notify(); }); n.WaitForNotification(); @@ -76,22 +100,27 @@ Status TransferManager::TransferLiteralToDevice( StatusOr> TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { + StatusOr> ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - StatusOr> ret; se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferArrayFromDevice(substream, shape, source, - [&](StatusOr> arg) { - ret = std::move(arg); + Literal literal(shape); + Status s; + TransferArrayFromDevice(substream, shape, source, literal, + [&](Status status) { + s = status; n.Notify(); }); n.WaitForNotification(); - return ret; + if (!s.ok()) { + return s; + } + return MakeUnique(std::move(literal)); } Status TransferManager::TransferArrayToDevice( @@ -130,7 +159,7 @@ Status TransferManager::TransferArrayToDeviceAsync( void TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, - std::function>)> done) { + const MutableBorrowingLiteral& literal, std::function done) { if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", @@ -147,7 +176,8 @@ void TransferManager::TransferArrayFromDevice( stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); + return TransferLiteralFromDevice(stream, shaped_buffer, literal, + std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 82c599e482d85fc5bbe5a5a48c6c6b053186803b..475a2e5c141d66fa689fb402da1ee81fb4ab80f7 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -59,6 +59,9 @@ class TransferManager { // This function should be avoided in favor of the asynchronous version below. virtual StatusOr> TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); + virtual Status TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal); // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -69,9 +72,10 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) = 0; + virtual void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -101,10 +105,10 @@ class TransferManager { // transfer an array at a known address. Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - std::function>)> done); + void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, + std::function done); Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, @@ -120,9 +124,9 @@ class TransferManager { // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + MutableBorrowingLiteral literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 7051a4cf51749d294478cf9a34d4700cb52ae312..58f767e913fbc0023e0c45a4f0e82ecefeeef2d6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 990dfc410ccf6ab84af00f4a16dc783c11985844..0447807a41b8b32ee297e1ca94393da8c687c5e6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -232,8 +232,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( // Copy the points-to set (and tuple sources) at index {element_index} of the // operand to the points-to set for this GetTupleElement instruction. points_to_set.ForEachMutableElement( - [&, this](const ShapeIndex& target_index, - PointsToSet::BufferList* points_to) { + [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) { // Construct an index into the operand by prepending element_index to // the index for the GetTupleElement instruction's points-to set. ShapeIndex src_index; @@ -308,7 +307,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // Recursively copy the points to set of the operand tuple {0} to the output // element {0}. points_to_set.ForEachMutableElement( - [this, &points_to_set, &operand_points_to_set]( + [&points_to_set, &operand_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { if (index.empty() || index[0] != 0) { return; @@ -517,7 +516,7 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction( const HloInstruction* instruction, TuplePointsToAnalysis::BufferDefinitionVector* buffers) { GetPointsToSet(instruction) - .ForEachElement([this, buffers, instruction]( + .ForEachElement([buffers, instruction]( const ShapeIndex& index, const PointsToSet::BufferList& source_buffers) { // Add buffers which 'instruction' is the source of. @@ -547,7 +546,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction); const PointsToSet& src_points_to_set = GetPointsToSet(src); dst_points_to_set.ForEachMutableElement( - [this, &dst_points_to_set, &src_points_to_set]( + [&dst_points_to_set, &src_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { *buffers = src_points_to_set.element(index); for (auto& tuple_source : src_points_to_set.tuple_sources(index)) { @@ -718,6 +717,7 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( // root at operand 0 or 1. Or... // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. +// (5) The 'user' of 'operand' is Sort, and it is the only user. // // (2) and (3) can only be determined if points-to analysis is available. bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( @@ -783,6 +783,21 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // TODO(b/62548313): Remove when buffer assignment is module scoped and // does not assign buffers to calls. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 0ac8df42714a1550d36560cbff901f6a8a4b3a8d..10d382e8abc92145c1804cbf18bbed714fa34571 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1012,6 +1012,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); @@ -1076,7 +1118,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -1085,7 +1127,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..af2cb6dc2a3f4a004351acc62796e0daf46719c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { + +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. +static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + // Look through copies. + // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU + // would run before copy insertion. + if (operand->opcode() == HloOpcode::kCopy) { + operand = operand->operand(0); + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. +// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +optional ComputeWhileLoopTripCount(HloInstruction* while_op, + int64 max_value_returned) { + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + // Now that we know the index of the induction variable, we can we can try to + // compute how many times the loop executes. Start by computing the induction + // variable's initial value. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr> indvar_init_result = + evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + + // The initial value of the induction variable. + std::unique_ptr indvar_iter_val = + std::move(indvar_init_result).ValueOrDie(); + for (int64 trip_count = 0; trip_count != max_value_returned + 1; + ++trip_count) { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr> result = + evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + if (result.ValueOrDie()->data() == + tensorflow::gtl::ArraySlice{false}) { + VLOG(2) << "Loop has static trip count of " << trip_count; + return trip_count; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. + StatusOr> indvar_next_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, + {{while_body_indvar, indvar_iter_val.get()}}); + if (!indvar_next_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_next_result.status(); + return nullopt; + } + indvar_iter_val = std::move(indvar_next_result).ValueOrDie(); + } + + VLOG(2) << "Loop has unknown trip count."; + return nullopt; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..bf59813e8c405a8709446bf8457729348ceae4ec --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Returns the precise trip count of the loop if it's statically known, +// nullopt otherwise. max_value_returned limits the number of steps that are +// evaluated while trying to brute force a loop trip count, trip counts larger +// than max_value_returned result in nullopt. +tensorflow::gtl::optional ComputeWhileLoopTripCount( + HloInstruction *while_op, int64 max_value_returned = 128); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index ec05a74e286c89dd8db5ae07580e461938d7c087..dd8697e680c56165f87c365a721eda2de1ebc085 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -26,23 +26,6 @@ namespace xla { using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; -// Finds and returns the non-constant operand in instr. -// -// CHECK-fails if instr doesn't have exactly one unique non-constant operand. -static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { - const HloInstruction* result = nullptr; - for (const HloInstruction* operand : instr->operands()) { - if (!operand->IsConstant()) { - if (result != nullptr) { - CHECK_EQ(result, operand); - } - result = operand; - } - } - CHECK_NE(result, nullptr); - return result; -} - // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. static bool IsOrContainsSendOrRecv(const HloInstruction* instr); @@ -72,211 +55,6 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { return false; } -// If all of instr's operands are either constants or have the form -// get-tuple-element(gte_operand, N) -// for the same value N, returns N. Otherwise, returns nullopt. -static optional GetGTEOperandIndex(const HloInstruction* instr, - const HloInstruction* gte_operand) { - VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " - << gte_operand->ToString() << ")"; - optional tuple_idx; - for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); - return nullopt; - } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); - } - return tuple_idx; -} - -// Tries to get the tuple index of the induction variable of a while loop. -// -// Checks that the loop condition and root both plumb the induction variable -// through the same tuple index, and that they both apply exactly one op to the -// induction variable before deciding whether to do another loop iteration (in -// the loop condition's case) or packing the induction variable into the result -// tuple (in the loop body's case). -// -// Specifically, checks that the loop condition has structure -// -// root = op(constants, get-tuple-elem(param0, N), constants) -// -// and the loop body has the structure -// -// inc = op(constants, get-tuple-elem(param0, N), constants) -// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). -// -// If so, returns N. Otherwise, returns nullopt. -static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Finding induction variable for loop " - << while_op->ToShortString(); - - // The while_cond computation should have the form - // - // while_cond_root = - // op(constants, get-tuple-elem(while_cond_param, N), constants). - // - // If it does, set indvar_tuple_idx to N. - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_param = while_cond->parameter_instruction(0); - optional indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); - if (!indvar_tuple_idx) { - VLOG(2) << "Induction variable not found in loop condition: " - << while_cond->root_instruction()->ToString(); - return nullopt; - } - - // The while_body computation should have the form - // - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = tuple(..., while_body_inc, ...) - // - // where while_body_inc is operand N of while_body_root. - auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple instruction: " - << while_body_root->ToString(); - return nullopt; - } - - auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); - auto* while_body_param = while_body->parameter_instruction(0); - optional while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); - if (!while_body_indvar_tuple_idx) { - VLOG(2) - << "Induction variable not found in while body increment instruction: " - << while_body_inc->ToString(); - return nullopt; - } - if (while_body_indvar_tuple_idx != indvar_tuple_idx) { - VLOG(2) << "Tuple index of induction variable does not match between loop " - "condition (" - << *indvar_tuple_idx << ") and while body (" - << *while_body_indvar_tuple_idx << ")"; - return nullopt; - } - - // Finally, check that the while loop's initial value is a tuple with enough - // elements. - auto* while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); - return nullopt; - } - - VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; - return indvar_tuple_idx; -} - -// Tries to determine the number of times the given loop executes. Currently -// simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Getting trip count for loop " << while_op->ToString(); - - // The loop's induction variable is found at - // - // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), - // - // where comp is while_op->while_body() or while_op->while_condition(). - optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx) { - return nullopt; - } - - VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx - << " in input tuple."; - - // Now that we know the index of the induction variable, we can we can try to - // compute how many times the loop executes. Start by computing the induction - // variable's initial value. - HloEvaluator evaluator(/*max_loop_iterations=*/0); - auto* while_init = while_op->mutable_operand(0); - auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); - return nullopt; - } - - // Evaluates the while loop's condition, returning either "true" (continue - // looping), "false" (stop looping), or nullopt (can't evaluate). - auto evaluate_while_cond = [&](const Literal& indvar) -> optional { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions(while_cond_root, - {{while_cond_indvar, &indvar}}); - if (!result.ok()) { - VLOG(2) << "Couldn't evaluate while cond: " << result.status(); - return nullopt; - } - return result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{true}; - }; - - // The initial value of the induction variable. - const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); - - // Evaluate whether the while condition is true when seeded with - // indvar_iter0_val. - optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); - if (while_cond_iter0_val == false) { - VLOG(2) << "Loop has static trip count of 0."; - return 0; - } - - // Calculate the value of the induction variable after one iteration of the - // loop, and check whether the while condition is true with this new value. - auto* while_body = while_op->while_body(); - auto* while_body_indvar_update = - while_body->root_instruction()->operand(*indvar_tuple_idx); - auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - StatusOr> indvar_iter1_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); - if (!indvar_iter1_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable update: " - << indvar_iter1_result.status(); - return nullopt; - } - const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); - optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); - if (while_cond_iter1_val == false) { - VLOG(2) << "Determined that loop has static trip count of 1."; - return 1; - } - - VLOG(2) << "Loop has unknown trip count >= 1."; - return nullopt; -} - // Tries to remove elements in a while loop's tuple that aren't used within the // loop. // @@ -577,7 +355,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { } // Remove while loops with static trip count of 0. - optional trip_count = GetLoopTripCount(while_op); + optional trip_count = + ComputeWhileLoopTripCount(while_op, + /*max_value_returned=*/1); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its // "init" operand. diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4391078b6484f25ba81aefa2c1d1f69d7d2774f4..c4c958be4a18f23b8e34f9e619e447c6bf4334b5 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -172,7 +172,7 @@ TEST_F(ShapeTreeTest, TupleShape) { // Write zero to all data elements. shape_tree.ForEachMutableElement( - [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; }); + [](const ShapeIndex& /*index*/, int* data) { *data = 0; }); EXPECT_EQ(0, shape_tree.element({})); EXPECT_EQ(0, shape_tree.element({0})); EXPECT_EQ(0, shape_tree.element({1})); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ec901af1e2057449452c4c65243593b016a26f61..34869cc5078699603c006387161fddd4fee4a9f8 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -596,8 +596,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { }; auto comma_list_to_int64s = - [&s, - string_to_int64](const string& input) -> StatusOr> { + [string_to_int64](const string& input) -> StatusOr> { std::vector results; for (const string& piece : tensorflow::str_util::Split(input, ',')) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); @@ -792,7 +791,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (LayoutUtil::IsSparseArray(shape)) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); tensorflow::gtl::ArraySlice padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e840067056868319fdc1ce8b4e9d0094565e05d3..42d52aee780e2aade0f2ed3597e653567b8da49b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -154,8 +154,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -192,8 +192,8 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -262,7 +262,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -290,8 +290,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -314,8 +314,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -334,8 +334,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -356,9 +356,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -376,9 +376,10 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -395,8 +396,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -419,9 +420,9 @@ xla_test( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -445,8 +446,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -464,9 +465,9 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -483,8 +484,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -501,8 +502,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -519,9 +520,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -543,8 +544,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -562,8 +563,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -586,8 +587,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -612,8 +613,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -638,7 +639,7 @@ xla_test( deps = [ ":client_library_test_base", ":literal_test_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -658,7 +659,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -681,8 +682,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -702,7 +703,7 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -725,8 +726,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -749,8 +750,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -773,8 +774,8 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -795,7 +796,7 @@ CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -838,8 +839,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -862,8 +863,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -891,10 +892,10 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -924,9 +925,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -950,8 +951,8 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -972,7 +973,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -991,8 +992,8 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1013,7 +1014,7 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -1044,8 +1045,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1065,9 +1066,9 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1096,9 +1097,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1123,9 +1124,9 @@ xla_test_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1164,9 +1165,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1187,7 +1188,7 @@ xla_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1241,8 +1242,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1260,7 +1261,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1283,8 +1284,8 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1306,8 +1307,8 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1327,8 +1328,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1345,8 +1346,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1362,8 +1363,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1387,8 +1388,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1408,8 +1409,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1438,8 +1439,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1458,7 +1459,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1481,9 +1482,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1507,8 +1508,8 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1527,8 +1528,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1549,8 +1550,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1572,7 +1573,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1593,8 +1594,8 @@ xla_test( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1612,8 +1613,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1635,8 +1636,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1656,8 +1657,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1673,8 +1674,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1687,8 +1688,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1708,8 +1709,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1793,8 +1794,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", @@ -1821,8 +1822,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", @@ -1858,8 +1859,8 @@ xla_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1886,8 +1887,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -1903,6 +1904,16 @@ xla_test( ], ) +xla_test( + name = "outfeed_in_nested_computation_test", + srcs = ["outfeed_in_nested_computation_test.cc"], + deps = [ + "//tensorflow/compiler/xla/tests:local_client_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "hlo_metadata_test", srcs = [ @@ -1912,7 +1923,7 @@ tf_cc_test( ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", @@ -1954,8 +1965,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1968,7 +1979,7 @@ xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], deps = [ - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2001,6 +2012,7 @@ xla_test( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -2053,8 +2065,8 @@ xla_test( ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -2075,7 +2087,7 @@ xla_test( ":client_library_test_base", ":literal_test_util", ":xla_internal_test_main", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 3ae96fa1bcb1057653a75db62def5556ae37f886..74f2e36f826cd82ce4015df857f3de67950beaeb 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 8d15b7841bc7298cd6865d8689cc496c0459e4b9..caeb0bf49a0dde9eeac02037b2ea04fd024d100c 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index 8c227df7f04e79ccc332062d0889d282c0f5e40f..af0b8522394a0c591e6c42ad12db8853ef66243c 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 6a024798f9e3faa5164b4dce6f43e517f6ab8eca..24b17b71007a1872462bed1f6b86ae1a5bb9922c 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -733,7 +733,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { var4D, [epsilon](float a) { return a + epsilon; }); auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( - var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); }); + var_add_epsilon, [](float a) { return 1 / std::sqrt(a); }); auto grad_output_times_var = *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon, diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 747c82b502c8ec9f8121641382d9fd3c9552b010..6c20f654fe3df6a28e9633cd832c11b487894bad 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 20cb989751ad69e2f3cf97c87c43293951f599ab..0d7a3aa46a9c12c19d954c11ae3a2cccbed886ef 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index d531e8fa82e47f7bcd278f10da2c205e44db0ac1..c6b5108fe9e5bcf843982676d822f1942359da71 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 50dd574624bb3874e682be5a272fb5bdefa4adc4..1d28e85b16596b0ec2717138fb2081878203e8b2 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 2086e38b91955b23ab11af73acd7faf46ca4bb18..b1d18210eaafdfec0920c0cccaa0dfdbd6de5609 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 0bc8facfe2cfcfab094f483137f6d8e241c6aaf9..a4eb57fc7b9abd460a7d158d0dc629eba88018cd 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 7a2e70d39f38ee32932c944b226fda2d1a9767f6..59d917054be2ebe3a25f902f51972a682a5231b6 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index f0f7ff1ea0cc6390691fc244b28bcb312b2c2e90..4a6e8a31241d39db21935576d57f0acb17caef11 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -26,8 +26,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 6ce2f844a34b01cd07df97a9bb12842490838be6..c898dacf489db97223e2918414daf5de88bece64 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index ff3824628676df8a37f3a98742d9423fd42928e4..7c52c9fbbb57f9291ea9f0966e2efa715819fb67 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 64bf8b3b3878f74e0557afc520c9ae342bd07c4a..5a06d061f0d83fff547502495ff8ab13fb421b70 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9f288634c0fa7d7dffa7f9c1af3a0752996cdec2..be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 369663de150964bd271b92d91dcd9276c87a763c..b27c1044baf2c0002f166c53a81e4361c60d012a 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 71d72a9828c5445be2cb1f559cf31363507bcd8d..49375748319ad5fe40db507a034ec4b07adb7e84 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 0fb6853e3f408488e79e52050f8b00c1ca073fef..1adc68cc4839dcd7d89741ec016f27bc9047c9a5 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 944366410b14439aa33999185525f1029735e95b..7b6bbc4f571af2e11306f95c24e243e78e0f4f4e 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index a8b8f74ca9603a71acefc0be2141d7b9caf2b73b..5ed8122e0073bde77bb2507a0ddd89c4365627c9 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 8792e7781b17465d94ae8ac8375a4523f368d720..6784c16715da72d337edf70fa51db42c59404136 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 1dc6ff0f4f51b51002cfb868a51457c08a259a80..5ef273e5a26ea8a16db864974c9bfa2c296cbce8 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 90f3d1b874f4da09104dc066c6642db1d2e77997..13c777835eb2d2519d39205cdc96f0aac4850c7d 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index d4b3aac85bff283515088f6e61c9d2bad11f60d3..5f234f36a8543ad408fb3430b27844beb16a54b5 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index a6a233e71aabc47c78ea291b71f8b831f1c60323..2db6503afab748d7b778e26b2f9350ac64c7778b 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 810947ab01b69b10b6ae60c551bd7aba10a6313d..3f3e8ab712fea14be9e4a7015effdf8ce518309b 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index d86fd7cc2d4da10ed726ca11a6d9f86287a5d11e..0e9e92ed996fbb34826d19b670c7c4920a1aad13 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -111,7 +111,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR1(&builder, {static_cast(2.0f)}); @@ -137,7 +137,7 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); @@ -148,7 +148,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); @@ -160,7 +160,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D( @@ -172,7 +172,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); @@ -183,7 +183,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto param0 = @@ -533,7 +533,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) { using T = TypeParam; XlaBuilder builder(this->TestName()); @@ -612,7 +612,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { using T = TypeParam; XlaBuilder builder(this->TestName()); @@ -648,7 +648,49 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(3); + dnums.add_rhs_contracting_dimensions(2); + dnums.add_lhs_batch_dimensions(0); + dnums.add_lhs_batch_dimensions(1); + dnums.add_rhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + + DotGeneral(x, y, dnums); + + auto x_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{9.0f, 10.0f}, {11.0f, 12.0f}}, + {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) + .ConsumeValueOrDie(); + + auto y_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, + {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR4( + &builder, + /*expected=*/ + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}}, + {x_data.get(), y_data.get()}, this->error_spec_); +} + +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { @@ -708,7 +750,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstLHS) { using T = TypeParam; auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -754,7 +796,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstRHS) { using T = TypeParam; std::unique_ptr> constant_rhs_array( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 88ac96d6b0f9206ef1ed0e4135495d7903ebf3f4..7f6f203a1ba48e0053f799c58bbbeae87aef1f7f 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index ebba13c5b39e241ff01c9ddcba8a1c04180b4bd0..5116e60ca63ef5f94b25b15e6616086fb9e44bbb 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 86bfaea4ef43ad382e497fd281ec5439f001b56f..bf1de02ba9dbd97db9ee31484402fe9b92385219 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 30dc639f117b9871238f0bf1628502cf8bef2e0c..39cc6c5927f1d416e31f689487efc10c20371abe 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index 0254ae1baaa864b38c3b217a5c2026d34b7f7d12..c5bbbe778df15d63a2586bd6291a7a33fc82aa52 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 607bcdd51ee8ff678cd84622a81f45d5525bb683..792be0d3fcd55621b9f8cdf0fdc28f7bb49294d1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index c5ca64fa3f0fc70542c577828d53eeecbd05067b..b77bece85ad1b2192b04330af9e60d3a424b59f4 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 73a47eda721971c75f61109787844c40be0b7080..51450314b611b49c643fb6fd5b0c0d2e7205a2d2 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -48,7 +48,8 @@ class UnaryOpTest : public HalfTestBase, public ::testing::WithParamInterface {}; XLA_TEST_P(UnaryOpTest, Ops) { - std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); + std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1), half(9.0), + half(42.0), half(-9.0), half(-100.0)}); XlaBuilder builder(TestName()); XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index 4d82442f7e3630c115eff1f17544e2b892c5e7eb..5511190caf95544e2ac48d91c0a138db06a2544c 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index b662e837168c8b16daea0181786be19fa0237a8c..0dce1b22a331e20054ba026f23e3284d7dd0e88a 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -233,6 +233,29 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } +::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + const auto& fake_arguments = + MakeFakeArguments(module_or_status.ValueOrDie().get()) + .ConsumeValueOrDie(); + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + return test_runner_ + .Execute(std::move(module_or_status.ValueOrDie()), + fake_argument_ptrs, /*run_hlo_passes=*/true) + .ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure(); +} + ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 66719b1460063a61541535ff7507468ae0ca1ada..bb55e562ad5e8fc32771d1ac23579db17a7e819d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -166,6 +166,8 @@ class HloTestBase : public ::testing::Test { const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string) + TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index f950aa1e8fe745075234a5ebff52d92be7378a5d..17ac95ae0198d98490b25f7f2edd32d1e0495803 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -34,7 +35,7 @@ class IotaTest : public ClientLibraryTestBase { } }; -TEST_F(IotaTest, SimpleR1) { +XLA_TEST_F(IotaTest, SimpleR1) { for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { { XlaBuilder builder(TestName() + "_f32"); diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 0df50150aee69749beea79ff522fb6f820d1945d..e2cd5bcc5a95f692dcf4a43d717252bfe876aa81 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 47cab796041e9669affaebd7866d0d80100730f1..115448c908ac9e7f0b01772ce348d23bf4d838ed 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -42,13 +42,12 @@ extern "C" void SumStructElements(float* out, void** parameters) { TEST_F(LocalClientAotTest, Constant) { xla::ExecutableRunOptions run_options; OpaqueData opaque_data{100, 20, 3}; - void* parameters[] = {&opaque_data}; float out = 0; - void* temporary_buffers[] = {nullptr, &out}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + void* temporary_buffers[] = {&opaque_data, &out}; + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 246.0f); opaque_data = {1, 2, 3}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 12.0f); } diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 70612e7c49d2815096cc54fd6ae796148249b4db..e310966d8b062f2baac00a17dd42cd449595d0d2 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,8 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -92,9 +92,10 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); - CHECK_EQ(result->buffer_sizes().size(), 2); - CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer + CHECK_EQ(result->buffer_sizes().size(), 3); + CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer + CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 5c3498c84cb68e4b1c4a7814284418f1ebbc0e98..1a823cf189b310c62c735419936544ea99fcfbaf 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index c31ba0e713a45d18b60bfdb9a47545cf34220333..eaddf756dbc913dd9668cd22228fbd18c2c33309 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 258226523d830b40ecaa761df95988dc90f5ca47..b4477e9a6b23363ee3a1380f9f98f4b8226f6920 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index cdf70ee4185be2ecd9dcb2d21fbd98c2ab6cc0ad..2d622242e657ce032a17f7b26c94227d343e2a38 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 7ddc6369319810c0806afa161bc00f51caea2072..0732e195d44d738b264361e43d38259c26a4116e 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 069b8a881f4be0c05b19bb1f323bdc13c7222ceb..da8c42d465340f2af3d6acd2c3676b69512f193f 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index e576f000ef23e761d6fa818457eec2144d4bcb00..955dbef6dcd28421fb351c6ee064ac53eda1fd08 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a0426adcbc1b5b89be0841fa2c4204e2b65abf4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +// Tests that ensure outfeed instructions that are contained in nested +// computations in non-root positions are executed. + +class OutfeedInNestedComputationTest : public LocalClientTestBase {}; + +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { + XlaBuilder b(TestName()); + + Shape state_tuple_array_shape = ShapeUtil::MakeShape(xla::S32, {10, 5}); + Shape int_shape = ShapeUtil::MakeShape(xla::S32, {}); + Shape state_tuple_shape = + ShapeUtil::MakeTupleShape({int_shape, state_tuple_array_shape}); + Shape xfeed_shape = ShapeUtil::MakeShape(xla::S32, {2}); + + XlaOp some_buffer = Broadcast(ConstantR0(&b, 0), {10, 5}); + XlaOp num_iter = Infeed(&b, int_shape); + XlaOp init_tuple = Tuple(&b, {num_iter, some_buffer}); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_cond, [&] { + // Condition: iteration variable > 0 + XlaBuilder cond_builder("loop_condition"); + XlaOp state_tuple = Parameter(&cond_builder, 0, state_tuple_shape, "state"); + XlaOp loop_counter = GetTupleElement(state_tuple, 0); + Outfeed(loop_counter, int_shape, ""); + Gt(loop_counter, ConstantR0(&cond_builder, 0)); + return cond_builder.Build(); + }()); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_body, [&] { + XlaBuilder body_builder("loop_body"); + XlaOp state_tuple = Parameter(&body_builder, 0, state_tuple_shape, "state"); + XlaOp loop_counter = GetTupleElement(state_tuple, 0); + XlaOp buffer_inside = GetTupleElement(state_tuple, 1); + + // Read some stuff from Infeed. + XlaOp some_input = Infeed(&body_builder, xfeed_shape); + XlaOp sum = Add(some_input, Broadcast(loop_counter, {2})); + Outfeed(sum, xfeed_shape, ""); + + XlaOp iter_left = Sub(loop_counter, ConstantR0(&body_builder, 1)); + + Tuple(&body_builder, {iter_left, buffer_inside}); + return body_builder.Build(); + }()); + + // Build loop. + XlaOp result_tuple = While(loop_cond, loop_body, init_tuple); + GetTupleElement(result_tuple, 0); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); + + std::unique_ptr comp_result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + comp_result = local_client_->ExecuteAndTransfer(computation, {}) + .ConsumeValueOrDie(); + })); + + VLOG(1) << "Transferring trip count to computation"; + // Transfer number of iterations to Infeed. + TF_ASSERT_OK( + local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + + // Pick up value from outfeed + { + VLOG(1) << "Reading from condition outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&int_shape)); + EXPECT_EQ(r->Get({}), 1); + } + + VLOG(1) << "Writing data to infeed"; + // Transfer some stuff to Infeed for use inside of loop. + TF_ASSERT_OK(local_client_->TransferToInfeed( + *LiteralUtil::CreateR1({10, 20}))); + + // Pick up value from outfeed + { + VLOG(1) << "Reading from body outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&xfeed_shape)); + EXPECT_EQ(r->Get({0}), 11); + EXPECT_EQ(r->Get({1}), 21); + } + + { + VLOG(1) << "Reading from condition outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&int_shape)); + EXPECT_EQ(r->Get({}), 0); + } + + // Joins the thread + thread.reset(); + + EXPECT_EQ(comp_result->Get({}), 0); +} + +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { + XlaBuilder b(TestName()); + + Shape condition_shape = ShapeUtil::MakeShape(xla::PRED, {}); + Shape result_shape = ShapeUtil::MakeShape(xla::PRED, {}); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation true_computation, [&] { + XlaBuilder inner_builder("true_computation"); + XlaOp param = Parameter(&inner_builder, 0, result_shape, "param"); + Outfeed(param, result_shape, ""); + Or(param, param); + return inner_builder.Build(); + }()); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation false_computation, [&] { + XlaBuilder inner_builder("false_computation"); + Parameter(&inner_builder, 0, result_shape, "param"); + return inner_builder.Build(); + }()); + + XlaOp pred = Infeed(&b, condition_shape); + Conditional(/*predicate=*/pred, /*true_operand=*/pred, + /*true_computation=*/true_computation, /*false_operand=*/pred, + /*false_computation=*/false_computation); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); + + std::unique_ptr comp_result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + comp_result = local_client_->ExecuteAndTransfer(computation, {}) + .ConsumeValueOrDie(); + })); + + TF_ASSERT_OK( + local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&result_shape)); + + EXPECT_EQ(r->Get({}), true); + + // Join the thread + thread.reset(); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index e428fa9b5e14d0cb6e5610a1b69b07c6b0c9952a..ca21b0b2ba590a6daadf2c8d3d9ad213514b0f0f 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 8ba1d11b33344463ffcb059f453754f31e177184..f6c762e7a4bee91a26c4c2e033c3717fef6d91d0 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 5c351b2d113709105244de4aafa49d7cc535ced1..2fc7f816b56db6f57ca835d1847476b6d622ce5e 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 3f98099be60ee4694a75f3200f130e49ed39fe67..326e13b3867f2f804e882e00e35850d0189ad8d7 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -182,7 +182,7 @@ XLA_TEST_F(PrngTest, Uniformity256) { XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. - auto build_sum_rng = [this](XlaBuilder& builder) { + auto build_sum_rng = [](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input"); Add(x, diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 526a38e8d1dbed9cdd4a31bfbec49bc5c6bb174b..fab2a65de109c670a6854c0fc1118162acf3d312 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 04c7f316463441d1bd458393b29ea5eb2acb9c9b..531648fe3eb8e3941c5e3c012847ee68c616590f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index e4a8ddf86a46703c308c58495ae26605f51a5d71..2065271a7f686c52c88df80b0efe8f2e1542d198 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index c2681f70f7e5727462c20d5eb3120bd34fd75765..1bd6fdab31d6c3516339bdb98459ffe3bbdef1d1 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index d544968648d7602464bd141a12c75eeb8c1678da..d8914513819415368a628eab1f482f9644dd46b1 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 7c0389cfa3251a6b62f83a78e986d870177d4d91..368f5583c9ce3773e57b858ff7606f679346529a 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 46d91711a55c5aa0ad906bc9ba0265fab194cf1a..382d1b1ae741285dcd1f7761edb82a5c333887af 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 23f0d26d93bf979970d112993c0a945fb4fe7d53..41e49b4003236d55d85592315652a0ddefd5c485 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 3b603c0d31565c20ff4e43c3ffd6001e9d54612e..e42c71eb284deb2e50d6ea4b47fa707e4bc14ffc 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index b1f1e69d3cdb9398a2ebc929a344a69adc6c2223..e3d4f98dd7432d1dce7e697586e8b17105dc82e7 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 59409ab26e1c19a8271318c18e19caa7b8ddc3b7..1c01402798658877889527a5dd02d5c74787ff99 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index a593faca0035b64670f294f81fd5b6d95f35cd88..b8ad6668f80a3002eff3cc458997966ee67c8d4b 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 8f424ae81f592bfd8accd8decb8fc363f7561c73..a2f0338e25977d7c76dbc48b3afc649b77ba4ee2 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 0f86b7f20f9bd7597ece713626ee0e9c23509e05..125513ddfd16cb4e742e7d589e22b721307621ee 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -60,7 +61,7 @@ class TransferManagerTest : public LocalClientTestBase { } protected: - Backend::StreamPtr stream_ptr_; + StreamPool::Ptr stream_ptr_; se::Stream* stream_; private: diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index 6ebb4324f8d20ed9f8886d92b0513441685ed19b..fbe9d1b64aa0c06d65b547c45cfa981800d40ff3 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index a517007591eac5aa7a190b8dd36bc2dd1b4b7ded..97bbf80aff80e995ea5cdd3e5d8807ee4d380067 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -586,9 +586,9 @@ XLA_TEST_F(TupleHloTest, })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); - auto literal = MakeUnique(); + auto literal = Literal::CreateFromShape(expected->shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), literal.get())); + backend().default_stream_executor(), expected->shape(), *literal)); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); } diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index a90a6fb0a5b5bb5119eee93c9c6a1377e3461b46..20ae68ab74026936c43e5f525eb796eb402a19cb 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index ea3aba6df1d3fbd492a23b280309322b8524c0bf..ef1b1445bbe555da00db4446d59439b752735a80 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 79bae22dac9599a38c73ea1dc2e6b4856395ff79..3848ec1684cdc9186e14ac0b60315b7520d127f3 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 29befef92e44c4f7673e0c7153efad31d2bbc2b1..1bdf1867b9330b715b0ba4aca71d56307883c775 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -1236,6 +1236,35 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {param_value.get()}, ErrorSpec(4e-5)); } +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { + auto while_shape = ShapeUtil::MakeShape(S32, {}); + + XlaComputation condition; + { + XlaBuilder builder("condition"); + Parameter(&builder, 0, while_shape, "state"); + Infeed(&builder, ShapeUtil::MakeShape(PRED, {})); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + XlaComputation body; + { + XlaBuilder builder("body"); + auto indvar = Parameter(&builder, 0, while_shape, "state"); + Add(indvar, ConstantR0(&builder, 1)); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + XlaBuilder builder(TestName()); + While(condition, body, ConstantR0(&builder, 0)); + + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + + ComputeAndCompareR0(&builder, 2, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index c000ff4dc8e34f22dfe0173d3876d0a10ef77c81..11f3efb1f34ad23ebdcbb65c90aa5fb7a6adeae5 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -18,10 +18,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -83,8 +84,8 @@ Status ParseOneProfileOutputLine( tensorflow::gtl::ArraySlice opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; - string match_percentage = "\\d+\\.\\d\\d%"; - string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; + string match_percentage = R"(\d+\.\d*% +\d+Σ)"; + string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; @@ -133,7 +134,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); TF_ASSERT_OK_AND_ASSIGN( - Backend::StreamPtr stream_ptr, + StreamPool::Ptr stream_ptr, backend->BorrowStream(backend->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN( @@ -224,7 +225,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { MaybeFind(parsed_profile_lines, "tanh")); EXPECT_GT(total_profile.cycles, 0); - EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ"); EXPECT_TRUE(HasFlops(total_profile)); EXPECT_TRUE(HasTrops(total_profile)); @@ -332,7 +333,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { EXPECT_GT(total_while_body_profile.cycles, 0); EXPECT_EQ(total_while_body_profile.opcode, "[total]"); - EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ"); EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles); EXPECT_NE(multiply_profile.cycles_percentage, "0.00%"); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 55501827f29582324ce3308f2a7d96bc20b65760..40d28a57bfddd3403cad8252df985b746362631f 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", @@ -84,7 +85,9 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", @@ -164,6 +167,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", @@ -181,6 +185,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", @@ -198,6 +203,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index befb55453777dce30af89bcaad2ffe1647097576..f20dcef382b86d27d7c176ae7e4132ad1db7b901 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index cfb8f37487d6499b803438a135be54524fcf17d2..f0af0580c1fbca455c6ed5f87f82971faee50a06 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 5dd5150be339846d0775880931f615b92c5b08d8..f03e1b1f965af761c101555fd0275bc0425b9cf0 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index a5dce20456c6a2402f425ebb3d575d1bb625f839..dc5c106d02cb679f3e6f5b2bea40bbb42f8bd1cc 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 854e797ec2e31d32d98f46a75c31ff89caac613b..be4cf4318b33f41fc611ea90a1a02198e23b84e4 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -30,6 +30,9 @@ limitations under the License. // The output format is: // // file_path: computation_name :: type:literal_str +// +// Note: If you pass multiple modules, they will be compiled in parallel but run +// in series. #include #include @@ -42,7 +45,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -74,6 +79,18 @@ struct Options { int num_runs = 1; }; +std::unique_ptr CompileExecutable(const HloSnapshot& module, + LocalClient* client) { + XlaComputation computation(module.hlo().hlo_module()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); + } + return client + ->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); +} + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // @@ -84,6 +101,7 @@ struct Options { // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, // no infeed is performed. StatusOr ReplayComputation(const HloSnapshot& module, + LocalExecutable* executable, LocalClient* client, const Options& opts) { XlaComputation computation(module.hlo().hlo_module()); @@ -166,34 +184,34 @@ StatusOr ReplayComputation(const HloSnapshot& module, }); } - std::vector argument_layouts; - for (const auto& param : computation.proto().program_shape().parameters()) { - argument_layouts.push_back(¶m); - } - std::unique_ptr executable = - client->Compile(computation, argument_layouts, ExecutableBuildOptions()) - .ValueOrDie(); - - // Do not attmept to run the executable, if num_runs is less than 1. + // Do not attempt to run the executable if num_runs is less than 1. if (opts.num_runs < 1) { return Cancelled("Cancelled after compilation since --num_runs < 1."); } // Run the computation num_runs times, and return the result from the last // execution. + const bool xla_hlo_profile = + legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { + // If xla_hlo_profile is enabled, print a noisy message before the last run, + // making it easier to separate this profile from the others in the logspam. + if (xla_hlo_profile && i == opts.num_runs - 1) { + LOG(INFO) << "\n\n***** Final run below ******"; + } ExecutionProfile profile; ExecutableRunOptions run_options; run_options.set_execution_profile(&profile); run_options.set_allocator(&allocator); TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); - LOG(INFO) << "Execution took " - << static_cast(profile.compute_time_ns()) / 1e9 << "s"; + LOG(INFO) << "Done executing in " + << static_cast(profile.compute_time_ns()) / 1e9 + << "s: " << module.hlo().hlo_module().name(); } TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, @@ -234,15 +252,39 @@ StatusOr ParseInputFile(const string& filename, int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; + + std::vector snapshots; for (char* arg : args) { StatusOr maybe_snapshot = ParseInputFile(arg, opts); - if (!maybe_snapshot.ok()) { - continue; + if (maybe_snapshot.ok()) { + snapshots.push_back(std::move(maybe_snapshot).ValueOrDie()); } - HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); - StatusOr result_status = ReplayComputation(snapshot, client, opts); + } + + // Compile all the modules in parallel. + LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; + std::vector> executables; + { + // ThreadPool CHECK-fails if we give it 0 threads. + tensorflow::thread::ThreadPool thread_pool( + tensorflow::Env::Default(), tensorflow::ThreadOptions(), + "compile_modules", std::max(size_t{1}, snapshots.size()), + /*low_latency_hint=*/false); + executables.resize(snapshots.size()); + for (int64 i = 0; i < snapshots.size(); ++i) { + thread_pool.Schedule([&snapshots, &executables, client, i] { + executables[i] = CompileExecutable(snapshots[i], client); + }); + } + } + LOG(INFO) << "Done compiling; now running the modules."; + + for (int64 i = 0; i < executables.size(); ++i) { + LocalExecutable* executable = executables[i].get(); + StatusOr result_status = + ReplayComputation(snapshots[i], executable, client, opts); if (!result_status.ok()) { - fprintf(stderr, "%s: error: %s\n", arg, + fprintf(stderr, "%s: error: %s\n", args[i], result_status.status().ToString().c_str()); exit_status = EXIT_FAILURE; continue; @@ -250,10 +292,11 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { if (opts.print_result) { Literal result = std::move(result_status).ValueOrDie(); - fprintf(stdout, "%s: %s :: %s:%s\n", arg, - snapshot.hlo().hlo_module().name().c_str(), + fprintf(stdout, "%s: %s :: %s:%s\n", args[i], + executable->executable()->module().name().c_str(), ShapeUtil::HumanString(result.shape()).c_str(), result.ToString().c_str()); + auto& snapshot = snapshots[i]; if (snapshot.has_result()) { std::unique_ptr literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 0b300dc7b2d03cc8e1564f78412cc610cff518cd..4c35e93d38450b8263290da8e327d1f2126c1532 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -447,6 +447,20 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; +} + message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -547,3 +561,11 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belongs to the same group. The ordering of the + // ids matters in some op (e.g., all-to-all). + repeated int64 replica_ids = 1; +} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index a173c51879f415ca98ad67015338ce5ab5e357a5..cc34db995e2ad653c6acce10d451de62ae8b264b 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -107,7 +107,6 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/timeseries", "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", @@ -130,8 +129,12 @@ py_library( "//tensorflow/contrib/bigtable", # depends on bigtable "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", - "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code + # TODO(aaroey): tensorrt dependency has to appear before tflite so the + # build can resolve its flatbuffers symbols within the tensorrt library. + # This is an issue with the tensorrt static library and will be fixed by + # the next tensorrt release, so fix the order here after that. "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows + "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index ded05da71877566781a5fb6d0c21e1c8d43de9ed..e18ea8df4df719a7317333cf9038ce7facf8d6ac 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function import os # Add projects here, they will show up under tf.contrib. +from tensorflow.contrib import autograph from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 7cbba7168383f3d0cdc80fda9908cb7d70836bb4..2d2ab7040a8bb76f9538f201f75a2e4dcba0f511 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -204,6 +204,7 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":converters", "//tensorflow/contrib/autograph/core:test_lib", diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py index 9c58ae3accaf7e3bca91750ecaf07845fafdbbae..38faba45df6746d56933a1647594af133b671628 100644 --- a/tensorflow/contrib/autograph/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -35,7 +35,7 @@ class AssertsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = asserts.transform(node, ctx) - self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call)) + self.assertTrue(isinstance(node.body[0].value, gast.Call)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 2a60750bdae273ca349c305b033313fa61f41872..180779670d91abd7d395bda0b63f592967c5015b 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -42,7 +42,7 @@ class BreakTransformer(converter.Base): var_name = self.state[_Break].control_var_name # TODO(mdan): This will fail when expanded inside a top-level else block. template = """ - var_name = True + var_name = tf.constant(True) continue """ return templates.replace(template, var_name=var_name) @@ -85,7 +85,7 @@ class BreakTransformer(converter.Base): guarded_orelse = self._guard_if_present(node.orelse, break_var) template = """ - var_name = False + var_name = tf.constant(False) while test and not var_name: body else: @@ -122,7 +122,7 @@ class BreakTransformer(converter.Base): # the control variable is marked as used. # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ - var_name = False + var_name = tf.constant(False) for target in iter_: (var_name,) body diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index c26ca2946ce40e30248d1d835bbe6517911540c0..fcae7d68c0f90817e001b45fa86ca6be08456027 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -20,13 +20,16 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class BreakCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, break_statements, {}) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_while_loop(self): @@ -40,9 +43,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 4) def test_for_loop(self): @@ -55,7 +59,8 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - with self.converted(test_fn, break_statements, {}) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: # The break is incompletely canonicalized. The loop will not interrupt, # but the section following the break will be skipped. self.assertEqual([3], result.test_fn([5, 4])) @@ -77,9 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 11) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 11) def test_nested_loops(self): @@ -99,10 +105,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 5) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 5) def test_loop_orelse(self): @@ -120,9 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index a36b3d77a9233daed864c616306b2ad27f582a38..2d1bed3367fa0b283200b775c5953da80c855367 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -238,7 +238,7 @@ class CallTreeTransformer(converter.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - ag__.converted_call(func, True, False, {}, args) + ag__.converted_call(func, True, False, False, {}, args) """ call_expr = templates.replace(template, func=node.func, args=node.args) new_call = call_expr[0].value diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 958bde0a58764e705c35ab73ce879b2c11ce7cdc..0476e97c15e33dcfc09b3555cf8dc7ff3fd7ce19 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base): def visit_Continue(self, node): self.set_local(CONTINUE_USED, True) template = """ - var_name = True + var_name = tf.constant(True) """ return templates.replace( template, var_name=self.get_local(CONTROL_VAR_NAME)) @@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base): if self.get_local(CONTINUE_USED, False): template = """ - var_name = False + var_name = tf.constant(False) """ control_var_init = templates.replace(template, var_name=continue_var) nodes = control_var_init + nodes diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index 3a7c7d1486de81482a191a321547ec1e67bf8618..37c15211b4fe266e57879249fe7e060ded44dc1f 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -20,13 +20,16 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ContinueCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, continue_statements, {}) as result: + with self.converted(test_fn, continue_statements, {}, + constant_op.constant) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_basic(self): @@ -40,10 +43,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) def test_for_loop(self): @@ -56,10 +60,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1]) - self.assertTransformedEquivalent(test_fn, [2]) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1]) + self.assertTransformedEquivalent(test_fn, [2]) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) def test_nested(self): @@ -78,10 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/contrib/autograph/converters/directives.py index ccdf79d47be65dd777a7ae3a226246a62e274430..77f625bac792621c45799d1a220f99eb4b99f7af 100644 --- a/tensorflow/contrib/autograph/converters/directives.py +++ b/tensorflow/contrib/autograph/converters/directives.py @@ -42,10 +42,30 @@ def _map_args(call_node, function): Returns: Dict[Text, ast.AST], mapping each of the function's argument names to the respective AST node. + Raises: + ValueError: if the default arguments are not correctly set """ args = call_node.args kwds = {kwd.arg: kwd.value for kwd in call_node.keywords} - return tf_inspect.getcallargs(function, *args, **kwds) + call_args = tf_inspect.getcallargs(function, *args, **kwds) + + # Keyword arguments not specified in kwds will be mapped to their defaults, + # which are Python values. Since we don't currently have a way to transform + # those into AST references, we simply remove them. By convention, directives + # use UNSPECIFIED as default value for for optional arguments. No other + # defaults should be present. + unexpected_defaults = [] + for k in call_args: + if (k not in kwds + and call_args[k] not in args + and call_args[k] is not directives.UNSPECIFIED): + unexpected_defaults.append(k) + if unexpected_defaults: + raise ValueError('Unexpected keyword argument values, %s, for function %s' + % (zip(unexpected_defaults, + [call_args[k] for k in unexpected_defaults]), + function)) + return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED} class DirectivesTransformer(converter.Base): diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py index 5f798a5b76dc42b0553cf8bffcf2c6227aabc67c..a2d083b891314d2f8f3fa61b46edc347ca8e24eb 100644 --- a/tensorflow/contrib/autograph/converters/directives_test.py +++ b/tensorflow/contrib/autograph/converters/directives_test.py @@ -23,6 +23,7 @@ from tensorflow.contrib.autograph.core import converter_testing from tensorflow.contrib.autograph.core.converter import AgAnno from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test @@ -38,7 +39,7 @@ class DirectivesTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {'directives': directives}) node = directives_converter.transform(node, ctx) - def_, = anno.getanno(node.body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].targets[0], anno.Static.DEFINITIONS) d = def_.directives[directives.set_element_type] self.assertEqual(d['dtype'].s, 'a') @@ -52,7 +53,7 @@ class DirectivesTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {'directives': directives}) node = directives_converter.transform(node, ctx) - def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) d = def_.directives[directives.set_element_type] self.assertEqual(d['dtype'].n, 1) self.assertEqual(d['shape'].n, 2) @@ -67,11 +68,27 @@ class DirectivesTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {'directives': directives}) node = directives_converter.transform(node, ctx) - d = anno.getanno(node.body[0].body[1], AgAnno.DIRECTIVES) + d = anno.getanno(node.body[1], AgAnno.DIRECTIVES) d = d[directives.set_loop_options] self.assertEqual(d['parallel_iterations'].n, 10) self.assertEqual(d['back_prop'].id, 'a') - self.assertEqual(d['swap_memory'], directives.UNSPECIFIED) + self.assertNotIn('swap_memory', d) + + def test_invalid_default(self): + + def invalid_directive(valid_arg, invalid_default=object()): + del valid_arg + del invalid_default + return + + def call_invalid_directive(): + invalid_directive(1) + + node, _ = parser.parse_entity(call_invalid_directive) + # Find the call to the invalid directive + node = node.body[0].body[0].value + with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): + directives_converter._map_args(node, invalid_directive) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/contrib/autograph/converters/error_handlers.py index 3f2366215268cffe1aa2c55a174dbdba6127d701..193682139438c1d0133b17165d7f7fb84e2eaaac 100644 --- a/tensorflow/contrib/autograph/converters/error_handlers.py +++ b/tensorflow/contrib/autograph/converters/error_handlers.py @@ -37,7 +37,8 @@ class ErrorRewritingTransformer(converter.Base): def visit_FunctionDef(self, node): node = self.generic_visit(node) - if anno.hasanno(node, anno.Basic.ORIGIN): + if (anno.hasanno(node, anno.Basic.ORIGIN) and + len(self.enclosing_entities) <= 1): template = """ try: body diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py index 878526c8b4825080af900a9e838f2d557e8d5273..5d61b220afa0fcf9a9e619bbd78f83a5076c473a 100644 --- a/tensorflow/contrib/autograph/converters/error_handlers_test.py +++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py @@ -34,11 +34,15 @@ class ErrorHandlersTest(converter_testing.TestCase): raise ValueError() node, ctx = self.prepare(test_fn, {}) - anno.setanno(node.body[0], anno.Basic.ORIGIN, - origin_info.OriginInfo('test_path', None, None, None, None)) + anno.setanno( + node, anno.Basic.ORIGIN, + origin_info.OriginInfo(None, 'test_function_name', 'test_code', + 'test_comment')) node = error_handlers.transform(node, ctx) with self.compiled(node, {}) as result: with self.assertRaises(errors.GraphConstructionError): + # Here we just assert that the handler works. Its correctness is + # verified by errors_test.py. result.test_fn() def test_no_origin_annotation(self): diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index f906918ac0c1cefcdd521c807113e16853517b1a..996e99ee61b3713a03ff167b892101fca35eaeac 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -79,7 +79,7 @@ class ListTest(converter_testing.TestCase): ns = {'special_functions': special_functions} node, ctx = self.prepare(test_fn, ns) - def_, = anno.getanno(node.body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].targets[0], anno.Static.ORIG_DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32'), @@ -114,7 +114,7 @@ class ListTest(converter_testing.TestCase): return tf.stack(l) node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].targets[0], anno.Static.ORIG_DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32') diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py index de1874321ec24a6a5d3625e8001df08b8e6edb1b..bee512abbc2e115d69bc9a5d53b6c54d428cc73a 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -43,7 +43,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: @@ -64,7 +64,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: @@ -84,7 +84,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, control_flow_ops.Assert) as result: with self.test_session() as sess: @@ -104,7 +104,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign_add) as result: with self.test_session() as sess: @@ -125,7 +125,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body[0].body), 1) + self.assertEqual(len(node.body[0].body), 1) with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result: with self.test_session() as sess: @@ -147,7 +147,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {}) node = side_effect_guards.transform(node, ctx) - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) with self.compiled(node, {}, state_ops.assign, state_ops.assign_add) as result: diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py index 9cfa0666729a2dffd08de78734b9aaecdc6ce576..c527f98613a2ffebf35141d4dac85e972a89c93b 100644 --- a/tensorflow/contrib/autograph/converters/slices.py +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -36,12 +36,14 @@ class SliceTransformer(converter.Base): def _process_single_assignment(self, target, value): if not isinstance(target, gast.Subscript): return None + if not isinstance(target.slice, gast.Index): + return None template = """ target = ag__.set_item(target, key, item) """ return templates.replace( - template, target=target.value, key=target.slice, item=value) + template, target=target.value, key=target.slice.value, item=value) def visit_Assign(self, node): node = self.generic_visit(node) @@ -76,7 +78,7 @@ class SliceTransformer(converter.Base): opts=ag__.GetItemOpts(element_dtype=dtype)) """ return templates.replace_as_expression( - template, target=node.value, key=node.slice, dtype=dtype) + template, target=node.value, key=node.slice.value, dtype=dtype) def transform(node, ctx): diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py index 3c0f81e8bc0aa888fca672142e22a661cc813b87..c822d53a4a2810755fd6841af85544dd8fc76a5e 100644 --- a/tensorflow/contrib/autograph/converters/slices_test.py +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -38,7 +38,7 @@ class SliceTest(converter_testing.TestCase): return l[1] node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32') } @@ -59,11 +59,11 @@ class SliceTest(converter_testing.TestCase): return l[1] node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.int32') } - def_, = anno.getanno(node.body[0].body[0].body[0].targets[0], + def_, = anno.getanno(node.body[0].body[0].targets[0], anno.Static.DEFINITIONS) def_.directives[directives.set_element_type] = { 'dtype': parser.parse_expression('tf.float32') diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py index a93e4a806469db63e7d767563e64dadfe71f50ee..83a80c1f52123c325782a67c651e892163af83b3 100644 --- a/tensorflow/contrib/autograph/core/converter.py +++ b/tensorflow/contrib/autograph/core/converter.py @@ -233,7 +233,7 @@ class Base(transformer.Base): arg_values = [] for def_ in defs: if (directive not in def_.directives or - arg not in arg not in def_.directives[directive]): + arg not in def_.directives[directive]): continue arg_value = def_.directives[directive][arg] for prev_value in arg_values: diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py index 2025e32817c11defa2818618e066eedc92c5db40..5ee2c3fffd7474cb8ca28349385a9d543e92a72d 100644 --- a/tensorflow/contrib/autograph/core/converter_testing.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -94,7 +94,8 @@ class TestCase(test.TestCase): return 7 try: - result, source = compiler.ast_to_object(node) + result, source = compiler.ast_to_object(node, include_source_map=True) + result.tf = self.make_fake_mod('fake_tf', *symbols) fake_ag = self.make_fake_mod('fake_ag', converted_call) fake_ag.__dict__.update(operators.__dict__) @@ -144,6 +145,7 @@ class TestCase(test.TestCase): recursive=True, autograph_decorators=()): node, source = parser.parse_entity(test_fn) + node = node.body[0] if namer is None: namer = FakeNamer() program_ctx = converter.ProgramContext( diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py index e58745337a3faac5e9f351174465443fa52fd6bc..5a57d57e7d4c6461f05030b72cc9bfe1b33210db 100644 --- a/tensorflow/contrib/autograph/core/errors.py +++ b/tensorflow/contrib/autograph/core/errors.py @@ -31,9 +31,10 @@ import logging import sys import traceback -from tensorflow.contrib.autograph.pyct.origin_info import CodeLocation +from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.python.framework import errors_impl -from tensorflow.python.util import tf_inspect + +# TODO(mdan): Add a superclass common to all errors. class GraphConstructionError(Exception): @@ -65,51 +66,43 @@ class TfRuntimeError(Exception): return message + ''.join(traceback.format_list(self.custom_traceback)) -def _rewrite_frame(source_map, cleaned_traceback, stack_frame_indices): - """Rewrites the stack frames at the given indices using the given source map. +def _rewrite_tb(source_map, tb): + """Rewrites code references in a traceback. Args: - source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and - AG generated code. - cleaned_traceback: List[Tuple[text, text, text, text]], the current - traceback. - stack_frame_indices: Iterable[Int], frame indices to possibly rewrite if - there are matching source mapping keys. - + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + tb: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. Returns: - None + List[Tuple[Text, Text, Text, Text]], the rewritten traceback """ - for frame_index in stack_frame_indices: - # (file_path, line number, function name, code) - file_path, line_number, _, _ = cleaned_traceback[frame_index] - source_map_key = CodeLocation(file_path=file_path, line_number=line_number) - found_mapping = source_map_key in source_map - if found_mapping: - cleaned_traceback[frame_index] = source_map[source_map_key].as_frame() - - -# TODO(znado): Make more robust to name changes in the rewriting logic. -def _remove_rewrite_frames(tb): - """Remove stack frames containing the error rewriting logic.""" - cleaned_tb = [] - for f in tb: - if 'ag__.rewrite_graph_construction_error' not in f[3]: - cleaned_tb.append(f) - return cleaned_tb + new_tb = [] + for frame in tb: + filename, lineno, _, _ = frame + loc = origin_info.LineLocation(filename, lineno) + origin = source_map.get(loc) + if origin is not None: + new_tb.append(origin.as_frame()) + else: + new_tb.append(frame) + return new_tb +# TODO(mdan): rename to raise_* def rewrite_graph_construction_error(source_map): """Rewrites errors raised by non-AG APIs inside AG generated code. - Meant to be called from the try/except block inside each AutoGraph generated - function. Only rewrites the traceback frames corresponding to the function - that this is called from. When we raise a GraphConstructionError at the end - it is then caught by calling functions, where they can be responsible for - rewriting their own frames. + This is called from the except handler inside an AutoGraph generated function + (that is, during exception handling). Only rewrites the frames corresponding + to the function that this is called from, so each function is responsible + to call this to have its own frames rewritten. + + This function always raises an error. Args: - source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and - AG generated code. + source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source + map belonging to the calling function Raises: GraphConstructionError: The rewritten underlying error. @@ -119,32 +112,17 @@ def rewrite_graph_construction_error(source_map): _, original_error, e_traceback = error_info assert original_error is not None try: - _, _, _, func_name, _, _ = tf_inspect.stack()[1] - # The latest function call is added to the beginning of a traceback, but - # when rewriting the traceback of multiple function calls the previous - # functions' except blocks may have already rewritten their own frames so - # we want to copy over all of the previous frames. We may have rewritten - # previous frames only if the error is a GraphConstructionError. + current_traceback = _cut_traceback_loops(source_map, + traceback.extract_tb(e_traceback)) if isinstance(original_error, GraphConstructionError): - cleaned_traceback = traceback.extract_tb(e_traceback) + # TODO(mdan): This is incomplete. + # The error might have bubbled through a non-converted function. previous_traceback = original_error.custom_traceback - cleaned_traceback = [cleaned_traceback[0]] + previous_traceback + cleaned_traceback = [current_traceback[0]] + previous_traceback else: - cleaned_traceback = traceback.extract_tb(e_traceback) - cleaned_traceback = _remove_rewrite_frames(cleaned_traceback) - - current_frame_indices = [] - # This code is meant to be called from the try/except block that wraps a - # function body. Here we look for all frames that came from the function - # that this wraps, look for any matching line numbers in the source - # mapping, and then rewrite them if matches are found. - for fi, frame in enumerate(cleaned_traceback): - _, _, frame_func_name, _ = frame - if frame_func_name == func_name: - current_frame_indices.append(fi) - break - if current_frame_indices: - _rewrite_frame(source_map, cleaned_traceback, current_frame_indices) + cleaned_traceback = current_traceback + + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) if isinstance(original_error, GraphConstructionError): original_error.custom_traceback = cleaned_traceback @@ -153,6 +131,7 @@ def rewrite_graph_construction_error(source_map): new_error = GraphConstructionError(original_error, cleaned_traceback) except Exception: logging.exception('Error while rewriting AutoGraph error:') + # TODO(mdan): Should reraise here, removing the top frame as well. raise original_error else: raise new_error @@ -161,70 +140,77 @@ def rewrite_graph_construction_error(source_map): del e_traceback +def _cut_traceback_loops(source_map, original_traceback): + """Check for cases where we leave a user method and re-enter it. + + This is done by looking at the function names when the filenames are from any + files the user code is in. If we find a case where we return to a user method + after leaving it then we cut out the frames in between because we assume this + means these in between frames are from internal AutoGraph code that shouldn't + be included. + + An example of this is: + + File "file1.py", line 57, in my_func + ... + File "control_flow_ops.py", line 231, in cond + ... + File "control_flow_ops.py", line 1039, in inner_cond + ... + File "file1.py", line 68, in my_func + ... + + Where we would remove the control_flow_ops.py frames because we re-enter + my_func in file1.py. + + The source map keys are (file_path, line_number) so get the set of all user + file_paths. + + Args: + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. + + Returns: + List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed. + """ + all_user_files = set(loc.filename for loc in source_map) + cleaned_traceback = [] + last_user_frame_index = None + last_user_user_file_path = None + # TODO(mdan): Simplify this logic. + for fi, frame in enumerate(original_traceback): + frame_file_path, lineno, _, _ = frame + src_map_key = origin_info.LineLocation(frame_file_path, lineno) + if frame_file_path in all_user_files: + if src_map_key in source_map: + if (last_user_frame_index is not None and + last_user_user_file_path == frame_file_path): + cleaned_traceback = cleaned_traceback[:last_user_frame_index] + last_user_frame_index = fi + last_user_user_file_path = frame_file_path + cleaned_traceback.append(frame) + return cleaned_traceback + + +# TODO(mdan): This should be consistent with rewrite_graph_construction_error +# Both should either raise or return. def rewrite_tf_runtime_error(error, source_map): """Rewrites TensorFlow runtime errors raised by ops created in AG code. Args: - error: error_impl.OpError, an TensorFlow error that will have its traceback - rewritten. - source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and - AG generated code. + error: tf.OpError + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo] Returns: - A TfRuntimeError with a traceback rewritten according to the given - source mapping. + TfRuntimeError, the rewritten underlying error. """ - # Check for cases where we leave a user method and re-enter it in the - # traceback. This is done by looking at the function names when the - # filenames are from any files the user code is in. If we find a case where - # we return to a user method after leaving it then we cut out the frames in - # between because we assume this means these in between frames are from - # internal AutoGraph code that shouldn't be included. - # - # An example of this is: - # - # File "file1.py", line 57, in my_func - # ... - # File "control_flow_ops.py", line 231, in cond - # ... - # File "control_flow_ops.py", line 1039, in inner_cond - # ... - # File "file1.py", line 68, in my_func - # ... - # - # Where we would remove the control_flow_ops.py frames because we re-enter - # my_func in file1.py. - # - # The source map keys are (file_path, line_number) so get the set of all user - # file_paths. try: - all_user_files = set(k.file_path for k in source_map) - cleaned_traceback = [] - last_user_frame_index = None - last_user_user_file_path = None - last_user_user_fn_name = None - for fi, frame in enumerate(error.op.traceback): - frame_file_path, frame_line_number, _, _ = frame - src_map_key = CodeLocation( - file_path=frame_file_path, line_number=frame_line_number) - if frame_file_path in all_user_files: - if src_map_key in source_map: - original_fn_name = source_map[src_map_key].function_name - if (last_user_frame_index is not None and - last_user_user_file_path == frame_file_path): - if last_user_user_fn_name == original_fn_name: - cleaned_traceback = cleaned_traceback[:last_user_frame_index] - else: - cleaned_traceback = cleaned_traceback[:last_user_frame_index + 1] - last_user_user_fn_name = original_fn_name - else: - last_user_user_fn_name = None - last_user_frame_index = fi - last_user_user_file_path = frame_file_path - cleaned_traceback.append(frame) - - for fi in range(len(cleaned_traceback)): - _rewrite_frame(source_map, cleaned_traceback, [fi]) + cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback) + # cleaned_traceback = error.op.traceback + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) + op_name = error.op.name op_message = error.message rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback) @@ -263,7 +249,7 @@ def improved_errors(converted_function): ValueError: If converted_function is not generated by AutoGraph """ if (getattr(converted_function, 'ag_source_map', None) is None or - not converted_function.ag_source_map): + not isinstance(converted_function.ag_source_map, dict)): raise ValueError( 'converted_function must be the result of an autograph.to_graph call') try: diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py index 7be54563a1a86a56437f4da2941bf5187ce813a9..404c1f5456f9654724d068e3007fe9ced15cbf07 100644 --- a/tensorflow/contrib/autograph/core/errors_test.py +++ b/tensorflow/contrib/autograph/core/errors_test.py @@ -28,88 +28,77 @@ from tensorflow.python.util import tf_inspect def zero_div(): - return array_ops.constant(10, dtype=dtypes.int32) // 0 + x = array_ops.constant(10, dtype=dtypes.int32) + return x // 0 def zero_div_caller(): - a = zero_div() + 2 - return a + return zero_div() class RuntimeErrorsTest(test.TestCase): - def setUp(self): - self._fake_origin = origin_info.OriginInfo('new file', 'new func', 96, 0, - 'print("hello world!")') - - def test_error_replacement(self): - _, zero_div_lineno = tf_inspect.getsourcelines(zero_div) - src_map = { - errors.CodeLocation( - file_path=__file__, line_number=zero_div_lineno + 1): - self._fake_origin - } + def fake_origin(self, function, line_offset): + _, lineno = tf_inspect.getsourcelines(function) + filename = tf_inspect.getsourcefile(function) + lineno += line_offset + loc = origin_info.LineLocation(filename, lineno) + origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code', + 'test_comment') + return loc, origin + + def test_improved_errors_basic(self): + loc, origin = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() with self.assertRaises(errors.TfRuntimeError) as cm: - z = zero_div_caller() - zero_div_caller.ag_source_map = src_map with errors.improved_errors(zero_div_caller): with self.test_session() as sess: - sess.run(z) - expected = cm.exception - current_traceback = expected.custom_traceback - for frame in current_traceback: - self.assertNotEqual('zero_div', frame[2]) - self.assertTrue( - any(self._fake_origin.as_frame() == frame - for frame in current_traceback)) - - def test_error_not_found(self): - src_map = { - errors.CodeLocation(file_path=__file__, line_number=-1): - self._fake_origin - } + sess.run(ops) + + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + self.assertNotEqual('zero_div', function_name) + self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback)) + + def test_improved_errors_no_matching_lineno(self): + loc, origin = self.fake_origin(zero_div, -1) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() with self.assertRaises(errors.TfRuntimeError) as cm: - z = zero_div_caller() - zero_div_caller.ag_source_map = src_map with errors.improved_errors(zero_div_caller): with self.test_session() as sess: - sess.run(z) - expected = cm.exception - current_traceback = expected.custom_traceback - self.assertTrue(any('zero_div' in frame[2] for frame in current_traceback)) - for frame in current_traceback: - self.assertNotEqual(frame, self._fake_origin.as_frame()) - - def test_rewriting_error(self): - _, zero_div_lineno = tf_inspect.getsourcelines(zero_div) - src_map = { - errors.CodeLocation( - file_path=__file__, line_number=zero_div_lineno + 1): - None - } - with self.assertRaisesRegexp(tf_errors.InvalidArgumentError, - 'Integer division by zero'): - z = zero_div_caller() - zero_div_caller.ag_source_map = src_map + sess.run(ops) + + all_function_names = set() + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + all_function_names.add(function_name) + self.assertNotEqual('test_function_name', function_name) + self.assertIn('zero_div', all_function_names) + + def test_improved_errors_failures(self): + loc, _ = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: 'bogus object'} + + ops = zero_div_caller() + with self.assertRaises(tf_errors.InvalidArgumentError): with errors.improved_errors(zero_div_caller): with self.test_session() as sess: - sess.run(z) + sess.run(ops) - def test_no_ag_source_map(self): + def test_improved_errors_validation(self): with self.assertRaisesRegexp( ValueError, 'converted_function must be the result of an autograph.to_graph call'): - with errors.improved_errors(None): - pass - - def test_bad_ag_source_map(self): + errors.improved_errors(zero_div).__enter__() with self.assertRaisesRegexp( ValueError, 'converted_function must be the result of an autograph.to_graph call'): - src_map = None - zero_div_caller.ag_source_map = src_map - with errors.improved_errors(None): - pass + zero_div_caller.ag_source_map = 'not a dict' + errors.improved_errors(zero_div_caller).__enter__() if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD index 2a4a0f75e7a120d554c882025ad2a0e280913a6d..6c281485b4a3c4d09292a4d7af16330cdc44edd4 100644 --- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD +++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD @@ -16,12 +16,26 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +py_test( + name = "errors_test", + srcs = [ + "errors_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + py_test( name = "keras_test", srcs = [ "keras_test.py", ], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ "//tensorflow:tensorflow_py", ], @@ -33,6 +47,7 @@ py_test( "list_literals_test.py", ], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b9159942bcf8837b97dfac000d8fb34d15a314 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py @@ -0,0 +1,162 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Error traceback rewriting integration tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib import autograph as ag +from tensorflow.python.util import tf_inspect + + +class ErrorsTest(tf.test.TestCase): + + def test_graph_construction_error_rewriting_call_tree(self): + + def innermost(x): + if x > 0: + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + return tf.zeros((2, 3)) + + def inner_caller(): + return innermost(1.0) + + def caller(): + return inner_caller() + + with self.assertRaises(ag.GraphConstructionError) as error: + graph = ag.to_graph(caller) + graph() + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_innermost_names = 0 + num_inner_caller_names = 0 + num_caller_names = 0 + ag_output_filename = tf_inspect.getsourcefile(graph) + for frame in custom_traceback: + filename, _, fn_name, _ = frame + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse(ag_output_filename in filename) + found_correct_filename |= __file__ in filename + self.assertNotEqual('tf__test_fn', fn_name) + num_innermost_names += int('innermost' == fn_name) + self.assertNotEqual('tf__inner_caller', fn_name) + num_inner_caller_names += int('inner_caller' == fn_name) + self.assertNotEqual('tf__caller', fn_name) + num_caller_names += int('caller' == fn_name) + self.assertTrue(found_correct_filename) + self.assertEqual(num_innermost_names, 1) + self.assertEqual(num_inner_caller_names, 1) + self.assertEqual(num_caller_names, 1) + + def test_graph_construction_error_rewriting_class(self): + + class TestClass(object): + + def test_fn(self): + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + + def inner_caller(self): + return self.test_fn() + + def caller(self): + return self.inner_caller() + + # Note we expect a TypeError here because the traceback will not be + # rewritten for classes. + with self.assertRaises(TypeError): + graph = ag.to_graph(TestClass) + graph().caller() + + def test_runtime_error_rewriting(self): + + def g(x, s): + while tf.reduce_sum(x) > s: + x //= 0 + return x + + def test_fn(x): + return g(x, 10) + + compiled_fn = ag.to_graph(test_fn) + + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_test_fn_frames = 0 + num_g_frames = 0 + ag_output_filename = tf_inspect.getsourcefile(compiled_fn) + for frame in custom_traceback: + filename, _, fn_name, source_code = frame + self.assertFalse(ag_output_filename in filename) + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse('ag__.' in fn_name) + self.assertFalse('tf__g' in fn_name) + self.assertFalse('tf__test_fn' in fn_name) + found_correct_filename |= __file__ in filename + num_test_fn_frames += int('test_fn' == fn_name and + 'return g(x, 10)' in source_code) + # This makes sure that the code is correctly rewritten from "x_1 //= 0" to + # "x //= 0". + num_g_frames += int('g' == fn_name and 'x //= 0' in source_code) + self.assertTrue(found_correct_filename) + self.assertEqual(num_test_fn_frames, 1) + self.assertEqual(num_g_frames, 1) + + def test_runtime_error_rewriting_nested(self): + + def test_fn(x): + + def g(y): + return y**2 // 0 + + s = 0 + for xi in x: + s += g(xi) + return s + + compiled_fn = ag.to_graph(test_fn) + + # TODO(b/111408261): Nested functions currently do not rewrite correctly, + # when they do we should change this test to check for the same traceback + # properties as the other tests. This should throw a runtime error with a + # frame with "g" as the function name but because we don't yet add + # try/except blocks to inner functions the name is "tf__g". + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + num_tf_g_frames = 0 + for frame in custom_traceback: + _, _, fn_name, _ = frame + self.assertNotEqual('g', fn_name) + num_tf_g_frames += int('tf__g' == fn_name) + self.assertEqual(num_tf_g_frames, 1) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py index 73125eb452fc3f3f94a8323d677341345931c4ea..7e7ef5a3e2bbf6a15936eb181c9c4112f8b820e6 100644 --- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py +++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py @@ -44,6 +44,33 @@ class ModelWithStaticConditional(object): return x +class BasicBlock(tf.keras.Model): + + def __init__(self): + super(BasicBlock, self).__init__() + self.conv1 = tf.keras.layers.Conv2D(8, 3) + self.pool = tf.keras.layers.GlobalAveragePooling2D() + self.dense = tf.keras.layers.Dense(3) + + def call(self, x): + x = self.conv1(x) + x = self.pool(x) + x = self.dense(x) + return x + + +class CompoundModel(tf.keras.Model): + + def __init__(self): + super(CompoundModel, self).__init__() + self.block = BasicBlock() + + @autograph.convert(recursive=True) + def call(self, x): + x = self.block(x) # pylint: disable=not-callable + return x + + class KerasTest(tf.test.TestCase): def test_basic(self): @@ -57,6 +84,20 @@ class KerasTest(tf.test.TestCase): model = ModelWithStaticConditional(True) self.assertEqual(model.call(), 25) + def test_recursive_true(self): + with self.assertRaisesRegexp(NotImplementedError, + 'Object conversion is not yet supported.'): + with tf.Graph().as_default(): + model = CompoundModel() + model.build(tf.TensorShape((None, 10, 10, 1))) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + sample_input = tf.random_uniform((1, 10, 10, 1)) + output = model(sample_input) # pylint: disable=not-callable + self.assertEqual(sess.run(output).shape, (1, 3)) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb index a3109fa5db2b895817545f5ff611c6979375f85b..7e9cc54d4cafa64e4cd3b48f9376b1b2b4d3575e 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -392,7 +392,7 @@ "output_type": "stream", "text": [ "Got error message: assertion failed: [Do not pass zero!]\n", - "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" + "\t [[{{node f/Assert/Assert}} = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" ] } ], diff --git a/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb b/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..32742bec7ee4a412aabb6640b5a1329353ebfc9d --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb @@ -0,0 +1,519 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "moMkWaT_TTHi" + }, + "source": [ + "This Colab illustrates the differing overhead* between a custom, vectorized graph operation and a loop over a tensor\n", + "that computes the same function. The loop is implemented in TensorFlow Eager mode using Python syntax and control-flow, and using AutoGraph which takes a python function and converts it into graph mode. In AutoGraph the Python loop is converted into a tf.while_loop.\n", + "\n", + "The actual computation, summing a small number of scalar values, takes very little time to compute, so the graphs below are showing the overhead of the differing approaches. As such, this is more of a \"micro-benchmark\" than a representation of real-world performance of the three approaches.\n", + "\n", + "*Note the differing scales of the included plots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "a0X_rfvuav98" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "EdxWv4Vn0ync" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "erq3_S7QsjkU" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import\n", + "from __future__ import division\n", + "from __future__ import print_function\n", + "\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "import time\n", + "import random\n", + "from colabtools import adhoc_import\n", + "from tensorflow.contrib import autograph as ag\n", + "from tensorflow.python.framework import function" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1JgnsXooa2RP" + }, + "source": [ + "### Testing boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UyD5LLjVZzny" + }, + "outputs": [], + "source": [ + "# Test-only parameters. Test checks successful completion not correctness. \n", + "burn_ins = 1\n", + "trials = 1\n", + "batches = 2\n", + "max_elements = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4_NBL0RQa8gY" + }, + "source": [ + "### Speed comparison parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Yq6daecyiJV5" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true} \n", + "burn_ins = 3 # Batches not counted in the average\n", + "trials = 10 # Batches run per vector-size (and averaged)\n", + "batches = 1000 # Number of random vectors summed over per trial\n", + "max_elements = 100 # Vectors of size 0 to this-1 will be executed and plotted" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fiR8m13CbKH2" + }, + "source": [ + "### Random input" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "d8vrTlyNXuxc" + }, + "outputs": [], + "source": [ + "# Construct a random num x 1 tensor\n", + "def get_elements(num):\n", + " return tf.random_uniform(shape=(num, 1), maxval=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ILJ6SbF3bXFQ" + }, + "source": [ + "## Graph mode" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "vovRf597X55n" + }, + "outputs": [], + "source": [ + "def tf_sum(elements):\n", + " # Using custom vectorized op\n", + " return tf.reduce_sum(elements)\n", + "\n", + "def run_trial(num):\n", + " elements = get_elements(num)\n", + " return tf_sum(elements)\n", + "\n", + "\n", + "\n", + "graph_means = []\n", + "for num in range(max_elements):\n", + " with tf.Graph().as_default():\n", + " durations = []\n", + " foo = run_trial(num)\n", + " \n", + " with tf.Session() as sess:\n", + " \n", + " for _ in range(burn_ins):\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " for _ in range(trials):\n", + " \n", + " start = time.time()\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration) \n", + " \n", + " graph_means.append(np.mean(durations)) " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 278, + "status": "ok", + "timestamp": 1532447361278, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "Jm9Blkyx90Eq", + "outputId": "d83cd51f-7e56-4d73-f7df-bb157dee46df" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa8AAAEcCAYAAABwNTvaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3WdgVFXegPFnZtI7kJ5QQwlCKIGERGroEoSAYZFVEFAR\ngV1XXHvbFWEtK6xlUVgRXMuLiqBSZQUh9E5ChxRIn/ReJjNz3g8hNxkzCUMJkHh+X2DmtnPPnXv+\n95R7ohJCCCRJkiSpGVHf6QRIkiRJ0vWSwUuSJElqdmTwkiRJkpodGbwkSZKkZkcGL0mSJKnZkcFL\nkiRJanaaJHitWLGCV199tSl23WINHz6cAwcONPlxXnzxRd5///0mP87dIikpiUmTJtGvXz++/PLL\nO52cFi0+Pp4HHnjgTifjmsaPH8+RI0du6T5/b/eVObcqX9966y3Wrl17zfWsbmTnffv2RaVSAVBe\nXo6NjQ1qtRqVSsUbb7zBE088cSO7vW5paWmMGDGCs2fPolY3n0rkiy++iLe3N0899dSdTspdoSmv\n46effsqAAQPYsGGD2eVbt27l888/5/z58/Tq1Yv//ve/JsvPnTvHyy+/TGJiIgEBASxevJjAwEBl\n+bvvvsu6detQqVQ88MADPPvssxZv29SmT5/OxIkTiY6Ovi3H++CDD3jsscea9BjDhw9n8eLFhIeH\n3/A+Nm3adAtT1LgVK1bwySefoFKp0Ov16PV67OzsEELg7+/Pxo0bCQwMxN7eHpVKhRACa2trDh8+\nfNvSeCPMlWG3Kl8fffRRpkyZQnR0NFZWDYeoGyopTpw4wfHjxzl+/Di+vr6sWLFC+W78+PE3nOjr\nJYRQLrjUfDXldUxPT6dz584NLndzc2PmzJnMmTOn3rKqqirmz59PVFQUR44cISoqinnz5qHX6wFY\nu3YtO3fuZOPGjfz000/s2rWLb775xqJtm4PruR7Z2dkcOnSIESNGNGGKbo7BYLjtx3ziiSeUsvHv\nf/87ffv25fjx45w4cYKNGzcCoFKp+Omnn5Tv73TguhP5VJeHhwcBAQHs3Lmz0fVu+jFXCFHvR/7R\nRx8pT6BpaWkEBgayfv16hg0bxoABA1i7di2nTp1iwoQJhIaGsmjRIpPt161bx7hx4xgwYACPPfYY\n6enpZo89ffp0APr3709wcDCxsbEIIVi+fDnDhw9n4MCBvPDCC5SUlJjdPj8/n7lz5xISEsKAAQN4\n+OGHlWWBgYGkpKQon+s2Cxw+fJihQ4fy6aefcu+99zJ48GB++eUXdu/ezZgxYxgwYAArVqwwe8xv\nv/2WjRs38umnnxIcHMyTTz6pLDt37hwTJkwgJCSEhQsXotPplGW//vorUVFRhISEMG3aNC5cuGB2\n/wAJCQnMnj2bAQMGcN9997F169YG121sv8OHD2fVqlVMmDCBvn378sorr5Cbm8vjjz9OcHAws2fP\npri4WFn/5MmTPPjgg4SEhBAVFWVyE06fPp3333+fadOmERwczKOPPkpBQYGyDEyvY3JyMtOnT6d/\n//6Eh4ezcOHCBs9hx44djB8/ntDQUGbMmEFiYiIAjzzyCIcOHeKNN94gODiYK1eu1Ns2PDycsWPH\n4uHhUW/Z4cOHMRgMzJgxA2tra6ZPn44QgoMHDwLwww8/MHv2bDw9PfH09GTWrFlKDe/QoUONblvX\nli1b6jW3rVmzhnnz5gGg0+l4++23iYiIYNCgQfztb38z+W388ssvREVF0a9fP0aPHs3evXtZtmwZ\nx44dY9GiRQQHB/Pmm28CcPz4caKjowkJCWHKlCmcOHHC5BotW7aMadOm0adPH1JTU1m/fj0jR44k\nODiYkSNHNvh0vW/fPnr06IGNjQ0AK1eu5M9//rPJOm+++SaLFy8GoKSkhJdffplBgwYxdOhQ/vWv\nf5mUI99++y3jxo0jODiY8ePHc+7cOZ577jkyMjJ48sknCQ4OZtWqVUD965+QkKDsZ/jw4fznP/9R\nfsMGg8GkiT4kJITg4GCCg4Pp27cvgYGBSnnT2L1x9uxZJk+eTL9+/Xj66aeprKw0my+WsPQhobGy\nraac/fbbbxk8eDCDBw9m9erVJtuuXLmSUaNGERYWxtNPP01RUZHJtuvWrSMiIoKZM2cC8NRTTzFo\n0CBCQkKYPn26kq8NlWF181Wn07F48WIGDx7MkCFDWLJkCVVVVUBt+bl69Wql/Fy/fr3JuYaEhLBr\n165rZshNiYiIEPv37zf57sMPPxTPPvusEEKI1NRU0a1bN/H666+LyspKsW/fPhEUFCTmz58v8vLy\nRGZmpggPDxdHjhwRQgjxv//9T4wePVokJiYKg8EgPv74YzF16lSzx05NTRWBgYHCaDQq33333Xdi\n9OjRIjU1VZSVlYkFCxYoafmt9957T7z++uvCYDAIvV4vjh49qiwLDAwUycnJyucXXnhB/Otf/xJC\nCHHo0CFxzz33iOXLlwu9Xi++/fZbERYWJp555hlRVlYmLl26JIKCgkRKSorZ49bdV918nDJlisjO\nzhaFhYXivvvuE2vXrhVCCHH69GkRHh4u4uLihNFoFBs2bBARERFCp9PV23dZWZkYOnSo2LBhgzAa\njeLs2bNiwIABIj4+vt6xr7XfiIgIMXXqVJGbmyu0Wq0IDw8XkyZNEufOnRM6nU7MmDFDfPTRR0II\nITIzM0VoaKiIiYkRQgixf/9+ERoaKvLy8oQQQjz88MNi1KhR4sqVK6KyslI8/PDD4r333mvwOi5c\nuFB88sknQgghKisrxbFjx8zmZWJioujTp4/Yv3+/0Ov14j//+Y8YNWqUqKqqUo773Xffmd22rm+/\n/VZMnz7d5LvVq1eLxx9/3OS7J554QqxevVoIIUS/fv1EbGyssuzUqVMiODjYom3rKi8vF8HBweLK\nlSvKdw888IDYsmWLEEKIN998Uzz55JOiqKhIlJaWirlz54qlS5cKIYSIjY0V/fr1U+5BrVYrEhMT\nzZ57QUGBCAkJET/99JMwGAxi06ZNIiQkRBQUFCjrR0REiPj4eGEwGERxcbEIDg4Wly9fFkIIkZ2d\nrfyOfuvtt98Wb7zxhvI5LS1N9OnTR5SUlAghhDAYDGLgwIFKfj355JPi9ddfFxUVFSI3N1dMmTJF\nfPPNN0IIIbZs2SKGDBkiTp8+LYQQIjk5WaSnpwshqn+TBw4cUI5zresfEREhoqKiRGZmpqisrFS+\n+22ZJYQQS5cuFQ8//LDQ6/WN3hs6nU5ERESIzz//XOj1erFt2zbRo0ePevf0b61fv1788Y9/rPd9\nt27dTMqahjRWttWUswsXLhQVFRXiwoULIiwsTDnP1atXi6lTpwqtVit0Op147bXXxMKFC022ff75\n50V5ebmST99//70oKysTOp1OLFmyREycOFFJS0NlWM3x/vWvf4mpU6eKvLw8kZeXJ6ZOnSref/99\nIURt+fnhhx8KvV4vdu3aJXr37i2KioqUfW3fvl1MmjSp0fy4LR1FKpWK+fPnY2Njw7333ou9vT2R\nkZG0atUKLy8v+vfvz9mzZwH45ptvmDNnDh07dkStVjNnzhzOnz9PRkZGYwFY+f+mTZuYOXMmfn5+\n2Nvbs3DhQrZs2YLRaKy3nZWVFdnZ2aSmpqLRaOjXr5/ZfZpjbW3N3Llz0Wg0jBs3jvz8fB555BHs\n7e3p3LkznTt3brR2ZM6MGTNwd3fHxcWFiIgIzp07B8B3333Hgw8+SFBQECqViqioKGxsbIiNja23\nj19//RV/f3+ioqJQqVR0796d0aNHs23btnrrWrLfhx9+mNatW+Pp6Un//v3p3bs3gYGBWFtbM2rU\nKCWNP/30E8OGDWPw4MFAdY2mZ8+e7N69W9nX5MmTadeuHTY2Ntx3333KtjXq5rmVlRVpaWlotVps\nbGwIDg42m2dbt25l2LBhhIeHo9FoePTRR6moqDCpUdyosrIynJ2dTb5zcnJSnnZ/u9zZ2ZmysjKL\ntq3Lzs6OESNGKLWay5cvk5SUpDTBrVu3jhdffBFnZ2ccHByYM2eOsu66deuIjo5W+oA8PT3p2LGj\n2fPZtWsXHTp04P7770etVhMZGUmnTp349ddflXUmTZpEQEAAarUajUaDRqPh4sWLVFZW4u7uTkBA\ngNl9FxcX4+joqHz29fXlnnvu4ZdffgHgwIEDODg40KtXL3JyctizZw8vvfQStra2tG7dmkceeYTN\nmzcr5/TYY4/Ro0cPANq2bYuPj4+y77q/E0uu/4wZM/Dy8lJqheZs2bKFTZs28eGHH6LRaBq9N2Jj\nY9Hr9cyYMQONRsOYMWPo2bNng/u2xKRJkwgJCSE0NFSpnf6WJWXbn/70J2xtbenatSuTJ09W8vTb\nb7/lL3/5C56enlhbWzN//nx+/vlnZVuVSsWf/vQn7OzslHyaPHky9vb2yvrnz59vsBXLXFrnz59P\nq1ataNWqFQsWLODHH39UlltbWzNv3jw0Gg1Dhw7FwcGBpKQkZbmjo6NJq445NzRg40a0adNG+b+d\nnR3u7u7KZ1tbW+WmT09PZ/Hixbz99ttAbX+IVqs1+QE3JCsrC19fX+Wzn58fer2enJwcPD09TdZ9\n7LHH+PDDD5k9ezYqlYopU6aY7fswx83NTRm0YmdnZ/Yca87JUnW3t7e3Jzs7G6jOkx9//FEZLSeE\nQK/Xk5WVVW8f6enpnDx5ktDQUGVdg8FAVFSU2XWvtd+6abK1ta33ue5127p1q1IQ1uyrbsd63Wtu\nb2/faP4899xz/Otf/yI6OlrplzI3ku2311ulUuHj44NWq21w35ZycHCod7OWlJTg5ORkdnlJSQkO\nDg4WbftbkZGRvPPOO8ybN49NmzYxcuRIbGxsyMvLo7y83OTcjUajUoBnZmYydOhQi87nt3kF1UGm\nbl55e3sr/7e3t2fZsmWsWrWKl156iX79+vHcc8/RqVOnevt2cXGhtLS03jlt3ryZiRMnsmnTJqU/\nPD09Hb1ez6BBg4Daroea+zszM5N27drd0DmZu/51z8mcs2fPsmjRIlavXo2bm5uSxsbuDS8vL5N9\n+Pn5WZTehmzYsIG2bds2uk5jZRtUn3vdc/X19eXSpUvK+SxYsEAZECWEwMrKStkWTPPJaDSydOlS\nfv75Z/Lz81GpVKhUKvLz8xv8DTeWVl9fX5Nyxc3NzWRwlp2dncnvp7S0tN7D32/dtuBlKW9vb558\n8kmLBn7UBI+6PD09TfrI0tLSsLKyMik4azg4OPD888/z/PPPk5CQwPTp0+nVqxdhYWHY29tTXl6u\nrJudnX3Nm6CpeHt7M3fuXItGcfr4+DBgwAClP+BW7deS40ZFRfHGG29c97bmrmObNm2UvtBjx44x\na9YsQkND693gnp6eyg1aIyMj45Zcqy5durBmzRqT7y5evKj00XXu3Jnz588TFBQEVPdZdunSpdFt\n6/ar1jVo0CBefPFFzp8/z+bNm3nppZcAaNWqFfb29mzatKnewxdUX8O6fbN1/TZfPT092b59u8l3\n6enpDBkypMFtBg4cyMCBA9HpdCxbtoxXX32Vr776qt6xunXrZvJkDTB27FjeeecdtFotv/zyizKY\nxcfHB1tbWw4dOmT22nt7e5OcnGzxOd3M9c/Ly2PBggW8/vrrJiNBG7s3jhw5Uu/hKD093eKAe6Ma\nK9syMjIQQpCRkaHUvDMyMpTfjI+PD0uWLKFv37719puWlgaY5u3GjRv59ddf+fzzz/H19aW4uJiQ\nkJDrSmtaWppSU09PTzf7+21IQkLCNUfm3pZmw2s1wdU1bdo0VqxYQXx8PFDdHGGuyQugdevWqNVq\nkx96ZGQka9asITU1ldLSUpYtW0ZkZKTZIdi7du1StnVwcFCaSaB6wMamTZswGo3ExMTc0vdC3N3d\nGyxwzPnDH/7A2rVriYuLA6qbpHbv3m225jJs2DCSkpL48ccf0ev1VFVVcerUKWUQw43u91omTJjA\nzp072bt3L0ajkcrKSg4fPmxRDcjcddy2bZuyrYuLC2q12uw1vO+++9i1axcHDx5Er9ezatUqbG1t\n6dOnj0XpNhqN6HQ69Hq9yf8BQkNDUavVfPHFF+h0OuUpfMCAAQBERUWxZs0atFotWq2WNWvWMHny\n5Ea3DQsLM5uOmuand955h6KiIgYOHAigtAgsWbKEvLw8ALRaLXv37gUgOjqa9evXc/DgQYQQaLVa\n5Vr/9nc2dOhQrly5wubNmzEYDGzZsoXExEQiIiLMpik3N5edO3dSXl6OlZWVco+YM3DgQM6cOWMy\nkKR169aEhITw4osv0rZtW6XG5uHhwcCBA1myZAklJSUIIUhJSVHusSlTpvDZZ59x5swZAJKTk5Vu\nA3d3d1JTU5Vj3Mz1NxgM/OlPf2LChAmMHTvWZFlj90afPn2wsrLiiy++wGAwsH37dk6dOnXN490s\nS8q25cuXU1FRwaVLl1i/fj2RkZEATJ06laVLlyrBLy8vjx07dijb/baMLi0txcbGBhcXF8rKynjv\nvfdMgtu1yrDIyEg+/vhj8vLyyMvLY/ny5UycONHicz1y5IjJQ5U5Nx28zD05XWudxj6PHDmSxx9/\nnKeffpr+/fszYcIE9uzZY3a/dnZ2zJ07l2nTphEaGkpcXBzR0dFMnDiRhx9+mFGjRmFvb88rr7xi\ndvvLly8zc+ZM+vbty7Rp03jooYeUp4uXX36ZnTt3EhISwubNmxk5cuRNnWNd0dHRxMfHExoayoIF\nC665fs+ePVm0aBFvvPEGoaGhjBkzpsH3lhwdHfnss8/YsmWLMurovffeMylULN3v9ZyTt7c3y5cv\nZ8WKFYSHhxMREcFnn32m3BSNbWvuOp46dYopU6YQHBzM/Pnzefnll802zXTs2JF3332XRYsWER4e\nzq5du/jkk0+U90Ou9fv88ccf6dWrF2+88QbHjh2jd+/eygv21tbWLF++nA0bNhAaGsr69etZvny5\nsu8HH3yQiIgIJkyYwIQJE4iIiOAPf/iDRduaExkZyYEDB7jvvvtMCqS//vWvtG/fnj/84Q/079+f\n2bNnc/nyZQB69erFkiVLWLJkCf369WPGjBlKQT9jxgy2bdvGgAEDWLx4MW5ubnzyySesWrWKsLAw\nVq1axYoVK3B1dTWbV0ajkdWrVzNkyBDCwsI4cuQIr7/+utm0t2nThrCwMKWPq8b48eM5cOAA999/\nv8n3b7/9NlVVVURGRhIaGspTTz2lNJOPHTuWuXPn8swzzyjXv7CwEIA5c+awfPlyQkNDWb169Q1d\n/5rvMjMzOX78OJ9//rky2jA4OJjMzMxG7w1ra2s+/PBD1q9fT2hoKNu2bWP06NENXtdrsaQMBSwq\n20JDQxk1ahSzZs3iscceU5rtH3nkEUaMGMHs2bPp168fDz74oBKYzaUhKioKHx8fhgwZwvjx4+vV\n2K5Vhs2bN4+ePXsyYcIEJk6cSM+ePZk7d65FeZCVlUVCQsK1y1xxPdUiSZKkBiQkJPDCCy/w3Xff\n3emk/O6kpaUxcuRIzpw506wmbDDn7bffpl27dkybNq3R9WTwkiRJauaa62xDN+P3cZaSJEktnKXN\njy2FrHlJkiRJzY6seUmSJEnNzl33ntfN0OsN5Odf/zDvlqhVKweZF1fJvKgl86KWzItaHh6NvxB8\nN2pRNS8rK/PvoPweybyoJfOilsyLWjIvmrcWFbwkSZKk3wcZvCSpBSsq1VFcVv8FdUlq7mTwkqQW\n7J9rT/LBurhrryhJzUyLGrAhSVIto1GQnlOKrY3s25FaHlnzkqQWqri8CqMQlFfq0Rvq/z07SWrO\nZPCSpBaqsKT2T9MXl1XdwZRI0q0ng5cktVAFJbUDNeSgDamlkcFLklqoujWvIhm8pBamyYNXTEwM\nY8eOZcyYMaxcubLe8qNHjzJ58mR69OhR76+8QvWfTh8yZAhvvvlmUydVklqUgtI6Na9S2WwotSxN\nGryMRiOLFi1i1apVbNq0ic2bN5OQkGCyjq+vL2+99Va9P1ZX4/333yc0NLQpkylJLZJpn5eseUkt\nS5MGr7i4ONq3b4+fnx/W1tZERkaa/OlpqA5eXbt2NTud/+nTp8nLy2PQoEFNmUxJapEK6/R5FckB\nG1IL06TBS6vV4uPjo3z28vIiKyvLom2FELz99ts899xzyL/aIknXr6BU1ryklqtJX1K+maDz9ddf\nM2zYMLy8vK5rX81xduSmIvOi1u8xL4rL9djbaiivNFCpF0oe/B7zoiEyL5qvJg1e3t7epKenK5+1\nWi2enp4WbXvixAmOHz/O119/TWlpKXq9HkdHRxYuXNjodtnZxTeV5pbCw8NZ5sVVv8e8EEKQV1iB\nn4cjqVkl5BaUkZ1d/LvMi4bIvKjVHIN4kwavoKAgkpOTSUtLw8PDg82bN7N06dIG169bu/rnP/+p\n/H/Dhg2cOXPmmoFLkqRqZVdn1WjlZEthSaUcKi+1OE3a56XRaHj11VeZPXs248ePJzIykoCAAD74\n4AN+/fVXAE6dOsXQoUPZtm0br7/+eoOjDiVJslxBcXV/l6uTDS4ONnLAhtTiqEQLGw0hmwGqySaR\nWr/HvDhzOY/31p5kwsAOJKQVcuZyPp88MxQ/X7ffXV405Pf4u2hIc2w2lDNsSFILVPOOl5uTLc6O\nNoCc31BqWWTwkqQWqOYdL1cnG5ztrwavctnvJbUcMnjdJpU6A2u2nufoecvec5Okm1EzKa+bky0u\njtYAFMkpoqQWRAav2+RkfA4xseks/+E0KzeeobRCFiRS0ym8+oKyq6MNzg41zYay5iW1HPIvKd8m\nlzOLAGjjYsfBM1ouJBfwxIQedG3rdodTdnuUlFdx8lIO+cUV5JfoqKjUM3FQR7xaO9zppLVIBSU6\nVICLow3ODtU1L9nnJbUkMnjdJpczilEBf58dwi9HU/lp32VWbz3PP+aE3emk3RZfbr/A4XOmTaat\nXGyZMqzzHUpRy1ZYUomzgzVWGjUuV2te8l0vqSWRzYa3gVEILmuL8XF3xMHOmgmDOtLZz4WsvDKq\n9IY7nbwmpzcYiUvIpY2LLU//oTcvPhwMQFp26R1OWctVUKrD1ckWoE7Nq+UEr5LyKjnn6e+cDF63\ngTavjEqdgQ7ete9SeLdxQADa/PI7l7Db5GJKARU6A327eBDUqQ1d/N1wdbIhNbvkTietRarQ6anU\nGXB1qq5x1fZ5tYxmw0upBfz5/T38c+1Jsgta/v0jmSeD121wObP6RUiT4NXaEYDM3LI7kqbbKTY+\nF4Dend2V7/w9nMgrqqRMDly55WqGybs5Vte87Gw0WGnULabmde5KvvLvq6sO8cvRFIyyFva7I4PX\nbXA5oyZ4uSjfeV8dqJCZ1/KDV1xCDrY2GpPBKf4e1cE7VTYd3nIFJbVTQwGoVCpcHK1bzFD5mt/M\nlIgArDVqvv7lEv/dduEOp0q63eSAjdvgcmYRKhW09XJSvvNu8/sIXpl5ZWjzy+nX1QNrq9pnJX+P\n6rxIyy753Yy4vF56g5HSCj2uV2fIsFRhae07XjWc7W3IyL27HxTKKqrYeiiZvKJKist1lJRVMaS3\nL8P6+pmsl5Zdgr2tFWND23FvTx+WfHGUg2cyeWhUF6ytNHco9dLtJmteTcxoFFzRFuPn7oitde2N\n5e5qh0atQtvCg1dsfA4AvTq3Mfm+JnjJmlfDvvrfRZ7/ZD9FpdfX3KdMylsn6Dk7WqPTG6mo1N/S\nNN5K2w4ns/nAFQ6cyeR0Yh6XM4v5+UiKyTpVegPavHL8PBxRqVS4OtrQr6snOr2R88kFdyjl0p0g\ng1cTy8gtRVdlNGkyBLDSqPFwsyczr6xFj5pSgleAu8n3Pm0cUKmQgzYaUFpRxf7TmeiqjCSmF13X\ntgVmal41w+VrmhTvNlV6I7tPpuNoZ8U/ngjj44VD6d6+Fdq8MpMX+tNzyjAKQVuP2laMoIDqB6O4\nq32rUjVtfhknLmbf6WQ0GRm8mpgyWMOn/qzN3q0dKK3QU1x+Y30Rd3vQK6uo4lJqIR19XOo1fdlY\na/Bq5UBqduldfx51ZeSWsnzDKfKLmzYIHDidSZXeCNS+4G6pwt/0eUHtcPnrrcXdLkfPZ1FcVsXg\n3r54tXLA1kZDJ9/qB76kOsG75mGnps8UoIu/K/a2GmITcprVb6mprd5yng/XnyIhvfBOJ6VJyODV\nxMwN1qihDNpoZMThtzvj+b9fLtX7Pq+ogqc+2Mv2w8m3KKW33umkPAxGQe+ANmaX+3s4Ul6pNwkE\nCWmFrN1xCYPReLuSeV02xCRy9EI2/zuaUm9ZUZmO9JybbwYVQhATm45GrQJqH4AsVTuvYW3wuttr\nXjuOp6ICIur0b9UEr8SM+sHLr07Ny0qjpkeH1uQUVpDxOxi9a4n84koupVQ3o/58uP5vtSWQwauJ\nXc4sQqNW0dbTsd6yaw3aiE8tZNvhZP53NKVeoRgTm05JeRUHz2pvfaJvQn5xJTmF5RSWVHL8apNF\n3SHyddX2e9U2Ha7deYntR1K4eBf2X+QUlnPs6jntP52J3lAbYIUQfPh9HH9bfZicm3z3KDG9iNTs\nUvp29aCNiy2XM4quq0ZRWKrD0c7KZPCCU03N6y4MXkkZRSSmF9G7szsebvbK9518XQFMmk1r+kjr\n1rygtlk6LkE2HQIcOZ+FADRqFccuZLXI9+Fk8GpCeoOR5KwS/DwczY6Camy4vBCC73cnKJ93nUxT\n/m8wGtkTlwHAFW3xLZvk1yjETTW7nLyUwzP/3sdzHx/g6Y/2cfhcFm5ONrSrM8qyLj9lxGF1gZSZ\nV0ZCWnVBdSn17mvq2Hk8DSGqB9sUleo4lVhbUF5MKSAhrQi9QbD54JWbOs7u2HQAhvb2pYO3C0Vl\nVdfVTFlYUqnMrlGjtuZ19zUb7jyWCsDwfqajCl0dbWjjYkdiem3wTs0uoY2LLQ521ibrKv1eCTm3\nIcV3vyPntKhUED0sACEw21LQ3Mng1YTSc0qp0tcfrFGjsWbDM5fzuJBSQI+OrXF1tGH/qUwqq6qn\nkjqVmEd+cSU21mqE4JbUUsor9Ty7fD///fnG3pcxCsH6mERUKgjr4UVIoCd9u7jz4IguqFQqs9v4\ne9a861WD0HifAAAgAElEQVRd89p/OlNZdjH17qp5VeoMxJxMx8XBmicm9gBg79UHCIBth6qbb53s\nrdkbl0FuYcUNHae8Us/hc1rcXe3o3qGV0ldqadNhld5gdnh9zSwbhXdZzauoTMehc1l4tXbgng6t\n6y3v5OtCSXkV2YUVFJfpKCzRmTQZ1nB1tKGDtzOXUgspq7gzIypv9kXp3zah36icgnIS0osIbNeK\nEf38aeVsy57YjBb3lyxk8GpCjQ3WgOpOdAdbq3o1r+paVyIAU4YFMLi3L2VXCzWAmJPVT+YPDAkA\n4Fxy/k2n9XRSdUDcfTKds5fzrnv7ExdzSM0uYcA9Xsy5vwdPRvXkTw/0IrS7V4PbeLjZY2OtJjW7\nFKMQHDidgZ2NBk83exLSiu6qfq/9pzMoq9QzrK8fAb6utPNyIjY+l8KSStJySolNyKWznysPjuiM\nwXjjta9DZ7XoqowM7u2LWqVSHnwsHbRRaKa/C8DlarPh3RS8Siuq+GFPEnqDkeHBfqjNPOQo/V7p\nhXWaDM3X5HsFtMFgFDf0+71Ze2LTmbd0NxdTbuyhq7xSz5v/Pcqrnx666Vlnjlz9m4ED7vHCSqNm\nZH9/KqsM7L5abrQUMnjdpLKKKtbuuGS2Tfn81WlsOjZQ81KpVHi3cSC7oNykoD52IZsrmcWEdvek\nnZczQ3v7olLBrhPp5BVVEJuQQwdvZ4b19cPaSq0c52bUHVL7xc8XrmvCYCEEG/cloQLGh3eweDu1\nSoWfuyMZuaWcu5xPblEl/QM96d6hFZVVBpK1t3YYfUZuKa+tOsSWRgJLld5ITGw6b311nP/+fIH8\n4kqMQvC/o6lo1CplQMHgXr7VAfeMlp+v1rrGDmjHgHu88Gxlz57Y6mt1PfQGI7+eSEOtUjEoyAeA\n9lenFKsZ+NMQo1EQl5DD59vOA9RrNqyted1cs2FZhf6mA2BiehErN55h4Uf72HUiDRdHGwb29DG7\nbm3wKjI70rCumr7V2NvcdGgUgi0Hr6CrMvLpprOUX+e7dEIIVm0+R0ZuGWWVemJiM669USMOn8tC\no1YR3NUDgKG9/bCz0fDL0RQycku5kJzP0fNZN9w6cLeQM2zcpB/2JvHL0VS0eWU8NaW38n1uYQVH\nzmfh08bBZGaN3/Ju7UBiehE5BRV4tXbAaBRs2JOIWqUianAnANq42tGrUxtiE3L5+pdLCAFD+/hi\nbaWmi78rZy/nU1Sqw+U6Z2KooTcYiU3IpY2LHX27uvPL0VS2HExm4qCOFm1/Mj6H5KwSQrt74utu\nvmBpiJ+HE0kZxayPqa5pDuzpTV5RdQ3wUkoBHX3MB/7rVVSqY9m3seQUVrBuVwKujjYMDKotMCt0\nenYcS+WXo6nKDBUXUwrYfyqD3p3dycwr496e3kpQGHCPF9/svMTO46nkF1fi1dqBPl3cUatUjA/v\nwGdbzrH54BWmj+5mUfqMQvDZ5nOkZFXXXls5Vx/Hyd4ad1c7LmcWI4Qw2wR75nIen289T87VwqiT\nr4vJuQHY2miwsVYrf6SyMWcu53HyYg4TB3fEyb62b6msooq/rT5CaYWev88Kwb3O4ApLZeWX8Y8v\nj2EwCrxa2TOkty/3BvngYGe+KGrv5YxGrSIpvQjd1Wbzhmpe7b2dcXG0ITY+l437krC11mBtrUGv\nN1JRZUBXZSCwfSt6mGmevBnnLuejzS/H0c6KnMIKvv01nkfGBlq8/ZaDVzh+MZvO/q6kaEv45VgK\no0L80ajN1y1yCsrJK640OzONNq+MK9pigjq1Ua6dg50VQ3r7sv1ICi//55Cybr+uHsyfHHSdZ3v3\naPKaV0xMDGPHjmXMmDGsXLmy3vKjR48yefJkevTowfbt25Xvz58/z4MPPsj999/PxIkT2bJlS1Mn\n9bplFZTz6/HqgRSxCbkmo6K2H0nBYBTcN6C92eaQGjX9XhlXmw73nsogI7eMgUHeyjJAmSLn+MVs\nbG00SnNc9/atADh/E02H55PzKa/U07erO5MGd6KVsy2bD1y2aOoqIQQ/7b2MCrh/oGXBri7/q8Eu\nKaMId1c7urR1o0vb6lFmF2/RoI0KnZ7318WRU1jB0D6+ONhasWbreaWJJz6tkL99doTvdydSWWVg\nbGg73pkbzsz7AnG0t1aaYUb1b6vs08nemr5dPMgprMBgFIwJbatc5/CeXni42bEnNp2kjPrNffFp\nhRw4k6k0Dwkh+HL7RQ6e1RLg58LM3xR8HXyq+33MPSnnFJTz8YbTFJRUMqS3D6/PDOGVGf3xM/MQ\n4eJgQ2EjfSpVeiPf7LzEe2tPsuN4Kh//cFoZUSmE4L8/XyCnsILySj2fbjqL0Xj9fTw198VDo7qy\nZE4Y94W1b3T6KxtrDf4eTlzRlnA5sxiNWqWM0v0ttUpFv64elJRXsWFPEmt3xvPFzxf4vx2X2BCT\nyOYDV1i+4dRN9f2cuZxXr1lv5/HqASd/eqAX/h5O7D6ZbvGoxzNJeayPSaSVsy0LJgUxMKj64e3Y\nBfMvF5+8lMNrnx3mra+Om32lpKZrIbS7p8n348LaE9bDi0G9fIgMb88fR3Zh6vDm/bf0mrTmZTQa\nWbRoEWvWrMHT05Po6GhGjBhBQECAso6vry9vvfUWn332mcm29vb2vPPOO7Rr146srCwmT57MkCFD\ncHJquBZzu22IScRgFAzr68euE2n8uDeJp//Qm5LyKmJi02nlbEtYj4b7fMB00EZlOwM/7EnExkqt\n1LpqBHVqQxsXO3KLKhjQ3Qt72+pLF6gEr4JG+5cac+JidTNLcBcP7G2tmDaiC8t/OM0XP1/gmQf7\nNBp8YxNyuaItJiTQ02yBeS1+nrXX896e3qhVKtxd7WntYsul1AKT2sbxi9moVSr6dDE/9N4cg9HI\ne18dIymjiHt7ejNjTDdCAz1Z+m0sH60/xb09vatHYgkYG9qO8fd2UGoBQ9zsCbvHi10n0hDUNuHV\nGNzLhyPns3BxsGZgT2/le41aTdSgTvxn01kWfX6UPp3dGX9vB/KLK9l2+IoyotLaSk2/bh7YWGmI\niU2nracTT0/pja2N6cjUjt7OHD2fxeXMYpPajt5g5OMfz1BWqWfWuEAG9/JtNC+cHayVl8Lr1uCM\nRkFSRhFf/HyB5KwSvFrZ08bVjrOX8/lmZzwPjerKgTOZHD6XRYCfC26Othy7mM22w8mMC2tv8bUo\nLtOxNy6DNi62DO3j2+BAnt/q5OvCFW0xydoS/D0csdI0/Mz94IguhPf0plJnoPJqbcvaSo2tjYaz\nl/PZdiiZnw+nMHlI7f1VXqlnQ0wirVxs6dGhNf6eTmZ/86cTc1n6bSyd/Vx57o99sdKoyS2s4GR8\nDu29neni78rj99/DG2uOsHrrOSYO6kh8aiHxqYU4OVgzc2wg/nV+76cTc1nx0xk0ahXzJvXExdGG\nUf3b8uvxNLYfSTG5n4UQbD5whQ0xiVhZqfFws2P7kRTSckqZO7EH1ho1567ks/dUBlYaFX27eJik\n3cXRhjn397Aov5uLJg1ecXFxtG/fHj+/6lpDZGQkO3bsqBe8gHo/5Pbta28KT09P2rRpQ15eXqPB\nq6lmD9AbjPxn41kc7ayYEtEZe1srkjKKOHRWS3tvZx4e3ZXM3FJOJeYSn1bI2aQ8KqsMRA3u2OiN\nBqbD5bcfTaGgRMf4e9srzUY11GoVYwe045ud8QwPrh1S3MHbGTsbjfJnIqC6U/7AuSw6eTnh1dr8\nU2oNoxCcuJSNk721UuPp182D3gHVzZQ/H07mvgHmC6hKnYFvdsYDcP/ADo0epyF1m4DC6wSALv5u\nHDqrJTOvDJ82jqTnlLJ8w2mMQjDzvkCG9DZfUFfqDOyJS+dCcgEZeWVk5ZehNwgC27kx875AVCoV\n3Tu0ZvqYbqzZep7tR1Jo42LHY+O7061dq3r7s7HWMDq0ndlj3dOhNSOC/Qls71bvVYjwnt44O1jz\n077LnIzP4WR8bT9Mn87udPB25sCZTA6eqX5S9mrtwDNT+9QbAg61f0onKbOI/oG1T9TrdiWQlFFE\neA8vpY+sMc4ONlTpiymr1KPNK+dMUi4XUwtJSCukQlfdJDe4lw/TRnZBCFjyxTF2HEvF0c6K7UdS\nsLPR8Pj9PXCwtSI+vZANMYn06NC6XlBvyK8n0tDpjYwKaXfN+6KuTr4u/HqiuoWjoSbDGtZWajr7\nuZpd1sXfjf2nM/nlaAqjQ9pSU7x/uf0iB85Uj3T9jgScHawZHdKWyN/032692rcZn1bI97sTmDq8\nC7tjq1+fGN7XD5VKRVtPJ6IGd+T73YnKTPf2thqyCspZ9N+jTBvRhUG9fPhhTxJbDl7BSqNi1n3d\nCbj6TptXawd6d3bnZHwO8WmFdPZzpbBUx5fbL3DsQjatXWz50+ReeLjZs3LjGeIScnn5P4eoqNSj\nuzojy5Devg02w7YkTXqGWq0WH5/am8rLy4tTp05d937i4uLQ6/W0a2e+EKnx0GtbmTaiC6NC2ja6\nXg0hBMnaElq52CrvwZizcd9lpenodFIecyb0YMPVPpo/DAtQ+qfe+uo463YlkJ5TiuPVduZr8Wxl\njwpISC/k8Dktzg7WDQaL4cF+DO7lg02dCX41ajVd27oRl5BLXlEFGo2at78+oTT5dW3rxuBePrT1\ndEKjVqFWq2jtbKc83SdlFFFQomNgkLfSxq5SqZg1rjuvrz7M97sS6eLvZrZA+L8dl9DmlTE6pO01\nC5WGuDra0M7TiVbOtni1qg20Xf1dOXRWy6XUQnzaOPLdr/EYhcDGWs3nW89jY6UmrEdtsCutqFL6\nrEquTrdlb6uhracTgR3bEBna1qTAHNLbl0qdgdyiCiYO6qjUZK+HWq3iodFdG1zes1MbenRszfkr\n+ew4noaTfXWhWNMveP/ADlxKLeR0Uh7D+vg22GdpbtDGiUvZbD+SgndrB6aP6WZRLaZmiqjnPt5P\neWXtgBzv1g508XelXzdPetWZDeVP0b1YtOYIP+27DMBj47vjebXm9+i47iz9NpaVG8/w2swQk0mn\nzdFVGdhxLBV7WysG97p2oK2rZtAGgF8DgzUsYWutYVxYe9buuMTPh5OZ2641h85qOXAmk44+zozs\n35azSXnEJuTy/e5EurVtRWf/6t/9lcxizl3Jp4u/K8VlVfx8OIWOPi7EXJ2PMfSe2lrSfQPao1ar\nsLPW0KWtG77ujsTG5/DZ5nP89+cL/Lg3icJSHZ5u9syN6lHvVZpRIW05GZ/DtkPJdGvnxg97Eimv\nNNC1rRvzonoqv5M/P9CLDXsS2XYouTroBbShd2f3BoN3S9OkwetWzDOWlZXFc889xzvvvHPNdVs5\n27J25yU6+LtxbyNNKAaj4NDpDNbviufClXzcnG15ZVYo3drX78i9mJzP5oNX8Gxlz+A+fqzfFc8/\nvjyGEBAc6MmQkOpA4+HhTJ/DKZy8VN1WPXVkV9r513+SN8ejtYPyou7MyHss3q5GSA9v4hJyOZ9W\nxC+Hk8nMK2NYP3/yCiuIi8+pN3zXyd6a52f0p09XT7ZcnTpmWP92eHjUPkF7eMDz00N45ZN9rNx4\nlg+eGaaMWAM4cCqdmNh0Ovq6MDe69039KYoPnx2uzAZQY0AvP77YfpHk7FLSCyqITcilZ0AbHpvQ\nk5c/3senm89hY2dNhc7A8QtZnE7IRVdlwMnemgdHdWP0gPa4u9k1Wqj/cdw9N5zm6+Hp6aL8Tswt\nGxh87YctH3dHkrNKcHd3YvfxVFZuPIuNlZqXZoXS1teywiqwozv7TmXiaGfN4D7+BHfzpGdAm3oj\nE2t4eDjz4sxQ/v7pQQb38WPCsNp39iI8nLmYVsSmfUms3nqBF2eG1KtNGY0C9dVruu3AZYrLqoge\n3uW6f99t2jjhaGdFaYWeHp09TH6n1yt6VDe2H0lmx7FURoZ14MvtF7Cz0fDCI6H4ejgxYRicTcrl\n+Y/28u2ueN57aihqtYrPf74IwEP3dcfd1Z6F78ew8qczGAVEDQ3A39d08MSM8T1NPo/2dKFvdx/+\n+dVRziblMaSPH/On9DZb03Z3d2Ld7gSOX8zm+MVsHO2tmTu5B2PDO5jcIwBzo/vw+OTe9b7/PWjS\n4OXt7U16eu27BVqtFk9Pz0a2MFVSUsLcuXNZuHAhvXr1uub6rz0Wxgsf7eWfXx3jWaPR7BNIfGoh\nqzafRZtfPbS9W1s3LqYW8MK/9zF7XKDJ07yuysA/vzyK0Sh4ZGwg3du3orOPM//ZdJbCEh0TwtuT\nnV37NDxuQDtOXsrG2kpNeHdPk2WN8XS1IyuvDK9W9gR3bmPxdjXaXu3A/vTH0wCMCPbnqWnB5OSU\nkFVQzqGzWopLdRiMAl2VgUPntLy+8iAPjujMvtg0bKzU+Le2r3dcb1dbJgzsyA97k3jn8yPMmXAP\ndjZW5BdX8v7aE1hbqZk9rjsF+bd+Pjk7DTjaWRF3KZv4q4NRJg/uiIuthqem9Oa9tSf56LtYZX1f\nd0cGBnkzrI9fdS1Krycnp3potYeH83Xn6d2mrYcjh3NK+ft/DnDsQjZ2NhrmTOyBk7Xa4nML7+7B\nkL5j0FfolCCkK9eRXd5wc7uvmx3LFgzC3laj5GeN+8Pbk5RWwOGzmbzz+REeHd8dtUpFUkYRq7ec\nI6ugnF4B7oQGevJ9TCIateq67ou6Ovq4cDopDxdbzU1fy7Gh7fj6l0s8/9FedFUGZt4XiDVC2a+H\nkw1h93hx8KyWDTsvck+HVuw5mYa/hyNtW9ujUqmYProrqzafA2BANw+L07RwSm+0+WXVk3IXV1Ba\nbH64emRYez7+4TThPb2JHhaAi4MNeblN9xcYbuaB4E5p0uAVFBREcnIyaWlpeHh4sHnzZpYuXdrg\n+nVralVVVcyfP5+oqChGjx5t0fE6+7vxZFRPPlgXxwfr4njuj31NmrMOn9Py6aZzGI2Cwb18GDug\nHT5tHDmVmMsnP55m5cazxKcV0qNja/zcHdlxLI2M3DJG9vNXRvV1a9eKxY+FUVymqzdUuLO/K9HD\nAnB1tLmuYettPZ04nZRH9LCA6+oLqOHvWftkOqS3D9NG1T4he7rZc/+9HUzWH9rHj482nOLrqxP+\nBnf1aLDZZ/y9HbiQUsDJ+BzmLY3BzckGlUpFaYWe6aO73tAgDUuoVSo6+7kSe3XUVngPb6V5pbOf\nKwun9mZPXAZd/F3p0aE1rV3smiQdd4sO3i4cPpfFsQvZtPV0Yt6knibNrJZQq1S0drEju/L6Rts1\n1H9ibaVm/uQg3lt7kgNnMnGws8LWWsPWQ1cQAtq42HH0fBZHrza5DwryqdeXa6lpI7uQll16S67z\n0D6+bD2UTH5xJcFdPcw2Y06J6MyJSzl8vzuB+FR3jEIwJrSdcl8NDPKpfgfQKK7Zr1yXWq3Cp821\n75ngrh6s+OswpeYq1acSTfw3BGJiYli8eDFCCKKjo5kzZw4ffPABQUFBREREcOrUKRYsWEBRURG2\ntrZ4eHiwceNGfvrpJ1566SW6dOmijI76xz/+QWBg4+9PZGcXExObzpqt51EBfbq4Mzqk7dVO1kTs\nbDTMi+pJz06mM52n55Tywbo4sn7zsrFXawf+Nuvabfo3o6xCT0pWsdkBA5baE5tOdmEFUYM6olar\nrlnbyCuq4IPv40jWljBnwj2E3ePd4LrFZTq2HLxCanYp2rwycgsrCO7qwbxJPS0eMXYjth68wne7\nErC2UvOPOWE3XHC1hJpXSlYJiz4/SngPLx4a1dWk3/N6NEVelJRX8fZXx0m7Onm0u6sds8Z1J7Cd\nGylZJRw+l8UVbTEzxnQzmXj3Tjp+MZtD57N4eGQXk+bwujbtv6y8f9jK2Za354bf0MNlc9Aca15N\nHrxut5ob89iFLDYfuGIyJ1wrZ1v+MqU3bT3NDy6o0Ok5dyWf9JxS0nJKyS+qZOqIzg3OTXg3s6SQ\nqqwycCWzmC7+rtcVhPQGIxq1qkkDF0BqVgmvrz7MxEEdmXAD75DVaAnBC6rz/WYLz6bKi/ziSlb+\ndIa2Xk5MHtIJO5u7f7TbtfKiSm/glU8PkV1QwZSIgAYHUrUEMnjdBer+GIUQJKQVsf1oCuWVemaP\n637DzRbNTUspsEvKq3C0s7qpQNlS8uJWkHlRy5K8iE8tZPfJNP44qusNjUhtLppj8Gq5V4PqId+d\n/V2V4a5S81N3eiJJut1k+XH3apkNuJIkSVKLJoOXJEmS1OzI4CVJkiQ1OzJ4SZIkSc2ODF6SJElS\nsyODlyRJktTsyOAlSZIkNTsyeEmSJEnNjgxekiRJUrMjg5ckSZLU7MjgJUmSJDU7MnhJkiRJzY4M\nXpIkSVKzI4OXJEmS1OzI4CVJkiQ1OzJ4SZIkSc2ODF6SJElSsyODlyRJktTsyOAlSZIkNTtNHrxi\nYmIYO3YsY8aMYeXKlfWWHz16lMmTJ9OjRw+2b99usmzDhg2MGTOGMWPG8MMPPzR1UiVJkqRmwupa\nK6SkpLBu3ToOHTpEZmYmtra2BAYGMmbMGEaPHo2VVcO7MBqNLFq0iDVr1uDp6Ul0dDQjRowgICBA\nWcfX15e33nqLzz77zGTbwsJC/v3vf7NhwwaEEEyePJkRI0bg7Ox8E6crSZIktQSNBq/XXnuNM2fO\nMHbsWP7617/i7u5OZWUlCQkJ7N27l5UrV/K3v/2NPn36mN0+Li6O9u3b4+fnB0BkZCQ7duyoF7wA\nVCqVybZ79+5l4MCBSrAaOHAge/bsYdy4cTd+tpIkSVKL0GjwGjFiBG+88Ua977t168a4ceMoKCgg\nJSWlwe21Wi0+Pj7KZy8vL06dOmVRwsxtq9VqLdpWkiRJatkaDV5Dhw5tdGM3Nzfc3NwaXC6EuLFU\nNbDtb2tn5nh4yGbFGjIvasm8qCXzopbMi+brmn1eAG+99Rbz58/H3t6eGTNmcPbsWf7+978zceLE\nRrfz9vYmPT1d+azVavH09LQoYd7e3hw6dEj5nJmZSVhY2DW3y84utmj/LZ2Hh7PMi6tkXtSSeVFL\n5kWt5hjELRptuH//fpydndm7dy9eXl78/PPP9QZYmBMUFERycjJpaWnodDo2b97MiBEjGly/bm1r\n0KBB7N+/n+LiYgoLC9m/fz+DBg2yJLmSJElSC2dRzavGkSNHGDVqFF5eXhY14Wk0Gl599VVmz56N\nEILo6GgCAgL44IMPCAoKIiIiglOnTrFgwQKKior49ddf+eijj9i4cSOurq7MmzePBx54AJVKxYIF\nC3BxcbnhE5UkSZJaDpWwoGNq1qxZ+Pn5sW/fPn744QccHR2ZNGkSGzduvB1pvC6yGaCabBKpJfOi\nlsyLWjIvarXYZsP33nuPzp07s2zZMlxdXcnMzGTWrFlNnTZJkiRJMsuiZsPWrVszc+ZM5bO/vz/+\n/v5NlSZJkiRJalSjwSssLKzRvq0DBw7c8gRJkiRJ0rU0Gry+//57ANatW0dBQQFTp05FCMH333+P\nl5fXbUmgJEmSJP1Wo8GrZlqnI0eO8OWXXyrfv/LKKzz88MM8/vjjTZs6SZIkSTLDogEbWVlZ5OXl\nKZ/z8vLIzs5uskRJkiRJUmMsGrDxyCOPEBUVxbBhwwDYvXs3TzzxRFOmS5IkSZIaZFHweuihh+jX\nrx9HjhxBCMFDDz1Et27dmjptkiRJkmSWxTNsBAYGEhgY2JRpkSRJkiSLWBS8jh8/zrvvvktKSgoG\ngwEhBCqVSg6VlyRJku4Ii4LXyy+/zLx58+jTpw9qtUVjPCRJkiSpyVgUvOzs7Lj//vubOi2SJEmS\nZBGLqlFDhgxh9+7dTZ0WSZIkSbKIRTWvb775hhUrVuDo6IiNjY3s85IkSZLuKIuCV800UZIkSZJ0\nN7AoePn5+aHX60lKSkKlUtGhQwesrK7r71hKkiRJ0i1jUQQ6deoUf/7zn5UmQ71ez4cffkiPHj2a\nOn2SJEmSVI9FwWvx4sUsWbKE8PBwAA4ePMiiRYtYu3ZtkyZOkiRJksyxaLRheXm5Erig+u98lZeX\nN1miJEmSJKkxFgUve3t7Dh48qHw+fPgw9vb2TZYoSZIkSWqMRc2GL730Ek899RQ2NjYAVFVV8cEH\nH1h0gJiYGJYsWYIQggceeIA5c+aYLNfpdDz//POcOXOGVq1asWzZMnx9fdHr9bzyyiucOXMGo9HI\nxIkT620rSZIk/T5ZFLx69erF9u3bSUpKQghBp06dsLa2vuZ2RqORRYsWsWbNGjw9PYmOjmbEiBEE\nBAQo66xbtw5XV1e2b9/Oli1bePfdd1m2bBnbtm2jqqqKjRs3UlFRwbhx4xg/fjy+vr43fraSJElS\ni2BRs+H+/fupqKiga9eudOvWjfLycoteUI6Li6N9+/b4+flhbW1NZGQkO3bsMFlnx44dTJo0CYAx\nY8YozZMqlYqysjIMBgPl5eXY2Njg5OR0vecnSZIktUAWBa933nnHJHA4OTnxzjvvXHM7rVaLj4+P\n8tnLy4usrCyTdbKysvD29gZAo9Hg7OxMQUEBY8aMwd7enkGDBjF8+HAeffRRXFxcLDopSZIkqWWz\nqNmwZjqoGmq1GoPBYNF217tOzbHi4uLQaDTs27ePgoIC/vjHPxIeHo6/v78lSZYkSZJaMIuCl6Oj\nI7GxsfTu3RuA2NhYHBwcrrmdt7c36enpymetVounp2e9dTIzM/Hy8sJgMFBSUoKrqyubNm1i8ODB\nqNVqWrduTXBwMKdPn75m8PLwcLbklH4XZF7UknlRS+ZFLZkXzZdFwevZZ59l/vz5dO7cGYD4+Hg+\n+uija24XFBREcnIyaWlpeHh4sHnzZpYuXWqyTkREBBs2bKB3795s27aNsLAwAHx8fDh48CATJkyg\nrKyM2NhYZs6cec1jZmcXW3JKLZ6Hh7PMi6tkXtSSeVFL5kWt5hjEVcKStj2gsLCQkydPIoSgb9++\nuLq6WnSAmJgYFi9ejBCC6Oho5syZwwcffEBQUBARERHodDqeffZZzp07h5ubG0uXLsXf35+ysjJe\nfNEYr50AABg9SURBVPFFEhISAHjggQeYNWvWNY8nf4zV5I1ZS+ZFLZkXtWRe1GrRwSspKYmEhARG\njhxJaWkpVVVVuLm5NXX6rpv8MVaTN2YtmRe1ZF7UknlRqzkGL4tGG27YsIEnn3ySf/zjH0B139Vf\n/vKXJk2YJEmSJDXEouD1+eef8/333+PsXB2dO3XqRE5OTpMmTJIkSZIaYlHwsra2xtHR0eQ7jUbT\nJAmSJEmSpGuxKHi5ubkpf4gS4Mcff1ReLJYkSZKk283iiXmfeeYZkpKSGD58OHZ2dnzyySdNnTZJ\nkiRJMsui4NWxY0e+++47Ll++jBCCjh07ymZDSZIk6Y6xqNkwKSkJvV5PQEAAGRkZrFq1isLCwqZO\nmyRJkiSZZVHw+stf/oJarSYlJYXXX3+dlJQUnn/++aZOmyRJkiSZZVHwUqvVWFtbs3v3bqZNm8ai\nRYvIyMho6rRJkiRJklkWBa/Kykq0Wi07d+5U5h60cGIOSZIkSbrlLApejzzyCJGRkTg6OhIUFERK\nSorywrIkSZIk3W4Wz21Yl8FgwGAwYGNj0xRpuilyrrJqct62WjIvasm8qCXzolaLm9vw9OnTZr/X\naDTY2Nig0+mUWd8lSZIk6XZp9D2vFStWUF5ezvjx4+nduzfu7u5UVlaSlJTEnj172L17Ny+88AIB\nAQG3K72SJEmS1Hjw+vDDD4mLi+Obb77h3//+N5mZmdjb29O1a1dGjhzJV199hZOT0+1KqyRJkiQB\nFsyw0atXL3r16nU70iJJkiRJFrFotKEkSZIk3U1k8JIkSZKaHRm8JEmSpGZHBi9JkiSp2bEoeOXm\n5vLXv/6Vhx56CIDz58/zf//3f02aMEmSJElqiEXB65VXXqFfv34UFRUB0KlTJ77++muLDhATE8PY\nsWMZM2YMK1eurLdcp9Px9NNPM3r0aKZOnUp6erqy7Pz58zz44IOMHz+eCRMmoNPpLDqmJEmS1LJZ\nFLy0Wi3Tpk1T/gCljY0NavW1NzUajSxatIhVq1axadMmNm/eXG9GjnXr1uHq6sr27dt55JFHePfd\nd4HqKaiee+453njjDTZt2sQXX3yBtbX19Z6fJEmS1AJZFLysrExfBysqKrJoVvm4uDjat2+Pn58f\n1tbWREZGsmPHDpN1duzYwaRJkwAYM2YMBw8eBGDv3r0EBgbStWtXAFxdXVGpVJYkV5IkSWrhLApe\no0eP5rXXXqO0tJT169cze/ZsHnjggWtup9Vq8fHxUT57eXmRlZVlsk5WVhbe3t5A9ZyJzs7OFBT8\nf3v3HhxVef9x/L1sAlJMgpiQRaS0JraQGqAzKsERIYBZIITsBiIMUsKlpdoBKqFYwck4crXGyUhk\nOhIBKzRMa4HIJRBSgxI6XGy1hZkCRUEn3JJwS5NgypLN8/sjP3YbgrBWNvEkn9df7Nlnz373yzN8\nOGfPPqeKL774AoAZM2aQlpbG6tWrA/1MIiLSxt12hQ2An/70p2zdupXq6mr27NnDT37yE1JTU2/7\nukCOzm4cY4zBZrPh9Xr55JNP2LRpE506dWLq1Kk89NBDvvuJiYhI+xVQeAGMHTuWsWPHfq2dOxyO\nJhdgVFRU0L1792ZjysvLiY6Oxuv1UltbS0REBA6Hg0ceeYSIiAgAnnjiCY4cOXLb8LLi0v7Bol74\nqRd+6oWfemFdAYXXxYsX+f3vf09ZWRn19fW+7StWrLjl6+Lj4ykrK+PMmTNERUVRWFhITk5OkzGJ\niYkUFBTQv39/ioqKfOH0+OOPs3r1aq5evYrdbuevf/0rU6dOvW2tuj9PI92ryE+98FMv/NQLPyuG\neEDh9Ytf/IK4uDgGDRrku+IwEHa7naysLKZPn44xhvHjxxMTE0Nubi7x8fEkJiaSnp7O/PnzSUpK\nomvXrr5wCw8PZ9q0aYwbNw6bzcbQoUMZMmTI//YpRUSkTQnoTspjx45l69atLVHPN6b/STXS/yr9\n1As/9cJPvfCz4pFXQFcb9u/fn3/961/BrkVERCQgAZ02nDhxIpMnT8bhcNCpUyff9o0bNwatMBER\nka8SUHjNnz+fZ555hri4uK/1nZeIiEgwBBRenTp1YsaMGcGuRUREJCABfec1ePBgSktLg12LiIhI\nQAI68nr33XfJy8ujS5cudOzY0bcKxv79+4Ndn4iISDMBhdemTZuCXYeIiEjAAgqvnj17BrsOERGR\ngN0yvObPn092drZvlYsb6VJ5ERFpDbcMr4yMDAB+/etft0gxIiIigbhleG3YsIFly5bx6KOPtlQ9\nIiIit3XLS+WPHj3aUnWIiIgELKDfeYmIiHyb3PK04fHjxxk0aFCz7fqdl4iItKZbhtf3vvc98vLy\nWqoWERGRgNwyvDp27KjfeImIyLfOLb/zCg0Nbak6REREAnbL8Hr33Xdbqg4REZGA6WpDERGxHIWX\niIhYjsJLREQsJ+jhVVpaysiRI3E6nTe97N7j8TB37lySkpKYMGECZ8+ebfL82bNn+fGPf8zbb78d\n7FJFRMQighpeDQ0NLF68mDVr1rB9+3YKCws5ceJEkzEbN24kIiKC4uJiMjIyyM7ObvL8K6+8wpAh\nQ4JZpoiIWExQw+vw4cP07t2bnj17EhoaSnJyMiUlJU3GlJSU4Ha7AXA6nU1W7Xj//ffp1asXsbGx\nwSxTREQsJqjhVVFRQY8ePXyPo6OjqaysbDKmsrISh8MBgN1uJzw8nKqqKurq6li9ejWzZs0KZoki\nImJBAd1J+X9ljPnaY66vm5ibm8vUqVPp3LlzwPsCiIoK+/qFtlHqhZ964ade+KkX1hXU8HI4HE0u\nwKioqKB79+7NxpSXlxMdHY3X66W2tpaIiAgOHz5McXEx2dnZVFdX06FDBzp16sTTTz99y/c8f74m\nKJ/FaqKiwtSL/6de+KkXfuqFnxVDPKjhFR8fT1lZGWfOnCEqKorCwkJycnKajElMTKSgoID+/ftT\nVFREQkICAPn5+b4xK1eupEuXLrcNLhERaR+CGl52u52srCymT5+OMYbx48cTExNDbm4u8fHxJCYm\nkp6ezvz580lKSqJr167Nwk1ERORGNhPol0kWodMAjXRKxE+98FMv/NQLPyueNtQKGyIiYjkKLxER\nsRyFl4iIWI7CS0RELEfhJSIilqPwEhERy1F4iYiI5Si8RETEchReIiJiOQovERGxHIWXiIhYjsJL\nREQsR+ElIiKWo/ASERHLUXiJiIjlKLxERMRyFF4iImI5Ci8REbEchZeIiFiOwktERCxH4SUiIpYT\n9PAqLS1l5MiROJ1O8vLymj3v8XiYO3cuSUlJTJgwgbNnzwKwb98+0tLSGDt2LOPGjePAgQPBLlVE\nRCwiqOHV0NDA4sWLWbNmDdu3b6ewsJATJ040GbNx40YiIiIoLi4mIyOD7OxsALp168aqVavYunUr\nr7zyCs8//3wwSxUREQsJangdPnyY3r1707NnT0JDQ0lOTqakpKTJmJKSEtxuNwBOp5P9+/cD0KdP\nH6KiogB48MEH8Xg8XLt2LZjlioiIRQQ1vCoqKujRo4fvcXR0NJWVlU3GVFZW4nA4ALDb7YSHh1NV\nVdVkTFFREXFxcYSGhgazXBERsYiQYO7cGPO1xxhjsNlsvseffvopOTk5rF27NqD3jIoK+3pFtmHq\nhZ964ade+KkX1hXU8HI4HL4LMKDxSKx79+7NxpSXlxMdHY3X66W2tpaIiAgAysvLmTVrFq+++ir3\n339/QO95/nzNnfsAFhYVFaZe/D/1wk+98FMv/KwY4kE9bRgfH09ZWRlnzpzB4/FQWFjI8OHDm4xJ\nTEykoKAAaDw9mJCQAEB1dTU///nP+dWvfsWAAQOCWaaIiFhMUMPLbreTlZXF9OnTGTNmDMnJycTE\nxJCbm8sHH3wAQHp6OpcvXyYpKYl33nmHefPmAZCfn09ZWRm//e1vcblcuN1uLl26FMxyRUTEImwm\nkC+mLESnARrplIifeuGnXvipF346bSgiItICFF4iImI5Ci8REbEchZeIiFiOwktERCxH4SUiIpaj\n8BIREctReImIiOUovERExHIUXiIiYjkKLxERsRyFl4iIWI7CS0RELEfhJSIilqPwEhERy1F4iYiI\n5Si8RETEchReIiJiOQovERGxHIWXiIhYTtDDq7S0lJEjR+J0OsnLy2v2vMfjYe7cuSQlJTFhwgTO\nnj3re27VqlUkJSUxatQo/vKXvwS7VBERsYighldDQwOLFy9mzZo1bN++ncLCQk6cONFkzMaNG4mI\niKC4uJiMjAyys7MB+Oyzz9i5cyc7duzgrbfe4uWXX8YYE8xyRUTEIoIaXocPH6Z379707NmT0NBQ\nkpOTKSkpaTKmpKQEt9sNgNPp5MCBAwDs3r2b0aNHExISwv3330/v3r05fPhwMMsVERGLCGp4VVRU\n0KNHD9/j6OhoKisrm4yprKzE4XAAYLfbCQsLo6qq6qavraioCGa5IiJiEUENr0BO891sjM1m+8rt\nIiIiIcHcucPhaHIBRkVFBd27d282pry8nOjoaLxeLzU1NUREROBwODh37pxvXHl5ebPX3kxUVNid\n+wAWp174qRd+6oWfemFdQT3yio+Pp6ysjDNnzuDxeCgsLGT48OFNxiQmJlJQUABAUVERCQkJAAwb\nNowdO3bg8Xg4deoUZWVl9OvXL5jlioiIRQT1yMtut5OVlcX06dMxxjB+/HhiYmLIzc0lPj6exMRE\n0tPTmT9/PklJSXTt2pWcnBwAYmNjGTVqFMnJyYSEhPDSSy/ptKGIiABgM7r+XERELEYrbIiIiOUo\nvERExHIUXiIiYjltJrxut4ZiW1ZeXs6UKVMYPXo0KSkprFu3DoB///vfTJ8+HafTyYwZM6ipqWnl\nSltOQ0MDbrebZ555BoDTp0/z1FNP4XQ6yczMpL6+vpUrbBk1NTXMmTPHd/HToUOH2u28+N3vfseY\nMWNISUlh3rx5eDyedjMvFi5cyGOPPUZKSopv263mwZIlS0hKSiI1NZWjR4+2Rsm31SbCK5A1FNsy\nu93OggUL2LFjB3/4wx/Iz8/nxIkT5OXlMWjQIHbt2sXAgQNZtWpVa5faYtatW0dMTIzv8Wuvvca0\nadPYtWsXYWFhbNy4sRWrazlLly5lyJAh7Ny5ky1btvDAAw+0y3lRUVHB+vXr2bx5M9u2bcPr9VJY\nWNhu5kVaWhpr1qxpsu2r5sGePXsoKyujuLiYRYsW8dJLL7VGybfVJsIrkDUU27KoqCj69u0LQJcu\nXYiJiaGioqLJupFut5v333+/NctsMeXl5ezZs4f09HTftgMHDuB0OoHGXvz5z39urfJaTG1tLX/7\n298YN24cACEhIYSFhbXbedHQ0EBdXR319fX85z//oXv37hw8eLBdzIuHH36Y8PDwJttunAfX/80s\nKSnB5XIB0L9/f2pqarhw4ULLFhyANhFegayh2F6cPn2aY8eO0b9/fy5evEhkZCTQGHCXL19u5epa\nxrJly3j++ed9vwu8fPkyERERdOjQON0dDke7mB+nT5/mnnvuYcGCBbjdbrKysqirq2uX8yI6Oppp\n06YxdOhQnnjiCcLCwoiLiyM8PLzdzYvrLl261GQeXLp0CWi63ix8e9eVbRPhpZ+qNbpy5Qpz5sxh\n4cKFdOnSpV3+qPvDDz8kMjKSvn37+uaFMabZHGkPvamvr+fIkSNMmjSJgoICOnfuTF5eXrv47Deq\nrq6mpKSEDz74gL1791JXV0dpaWmzce2xNzeyyrqyQV1ho6UEsoZiW1dfX8+cOXNITU1lxIgRANx7\n771cuHCByMhIzp8/T7du3Vq5yuD75JNP2L17N3v27OHq1atcuXKFZcuWUVNTQ0NDAx06dAh4nUyr\nczgcOBwO4uPjAUhKSuKtt95ql/Ni37599OrVi65duwIwYsQI/v73v1NdXd3u5sV1XzUPoqOjKS8v\n9437tvalTRx5BbKGYlu3cOFCYmNjycjI8G0bNmwYmzdvBqCgoKBd9CQzM5MPP/yQkpIScnJyGDhw\nIK+99hoDBw6kqKgIaD+9iIyMpEePHnz++edA4/d+sbGx7XJe3HfffRw6dIirV69ijOHAgQM8+OCD\n7Wpe3HhE9VXzYPjw4bz33nsA/OMf/yA8PNx3evHbpM0sD1VaWsrSpUt9ayjOnDmztUtqMR9//DGT\nJ0/mBz/4ATabDZvNxty5c+nXrx/PPfcc586d47777mPFihXNvrRtyz766CPWrl3Lm2++yalTp8jM\nzKS6upq+ffuSnZ1NaGhoa5cYdMeOHePFF1+kvr6eXr16sXz5crxeb7ucFytXrqSwsJCQkBDi4uJY\nsmQJ5eXl7WJezJs3j4MHD1JVVUVkZCSzZ89mxIgR/PKXv7zpPFi0aBF79+6lc+fOLF++nB/96Eet\n/AmaazPhJSIi7UebOG0oIiLti8JLREQsR+ElIiKWo/ASERHLUXiJiIjlKLxERMRyFF5iScOGDeOz\nzz5rkfdauXJlk1tlLFiwgPz8/G+83wULFpCSkkJmZuY33tetHDt2jJ07dwb1PURamsJL5DZWrlzJ\ntWvX7ug+L1y4QHFxMdu2bSMnJ+eO7vtGR44c+Z/Dq6Gh4Q5XI3JnKLykTfn888/52c9+Rnp6Oi6X\ny7f8DUCfPn1YtWoV48eP58knn6S4uNj33K5duxg1ahRpaWmsWrWKPn36UFdXx6JFi7DZbEycOBG3\n201tbS0Ax48fJyMjA6fTyQsvvPCV9bz33nukpKSQmprK7NmzuXTpEleuXCEjI4OrV6/idrt55513\nmrxmy5YtzJo1y/fY6/UyePBg3/qdq1ev5qmnniItLY1nn32WixcvAnDt2jV+85vfkJKSgsvlYvbs\n2VRVVfHGG29w4MAB3G43S5cuBRpXpHG73aSmpjJt2jROnToFNK5K4nK5WLJkCRMnTmTv3r3f5K9D\nJHiMiAUlJiaaTz/9tMm2+vp643a7zcmTJ40xxtTW1hqn0+l7/MMf/tDk5+cbY4z5+OOPzeDBg40x\nxly4cME8+uijpqyszBhjzNtvv2369OljvvzyS9/r6urqfO/zwgsvmEmTJhmPx2M8Ho9JTk42+/bt\na1bj8ePHzeOPP24uXLhgjDHm9ddfN88995wxxpjTp0+bhISEm362uro6k5CQYC5fvmyMMWb37t0m\nIyPDGGPMli1bTFZWlm/shg0bzLx584wxxrzxxhtm9uzZpr6+3hhjfK/fvHmzmTNnju81Fy9eNAkJ\nCebEiRPGGGP+9Kc/mfT0dGOMMQcPHjRxcXHm0KFDN61N5NtCR17SZnzxxRecPHmSzMxMXC4XTz/9\nNNeuXWtyV+3Ro0cDMGDAAM6fP4/H4+HQoUM89NBD9OrVC4Dx48c327e5YRW1ESNGEBoaSmhoKHFx\ncZSVlTV7zcGDBxk6dCj33nsvABMnTmTfvn23/Rx33XUXw4cPZ/v27UDjoqnXbyi5e/du9u/fj8vl\nwuVysWHDBs6dOwc03g5mypQp2O12AN8K6jc6dOgQffv25YEHHgBg3LhxHD16lC+//BKA3r17069f\nv9vWKdKa2sQtUUSgMWC6detGQUHBTZ+32Wx06tQJwHcDQq/X2yyYbnx8Mx07dvT92W63N7mg47/3\nc+N9kK6/7+24XC6WL1/OmDFj+Oijj8jOzvbt89lnnyUtLe2m7xeIm9X134+/853vBLQfkdakIy9p\nM77//e9z1113sWXLFt+2kydPcuXKFaD5P+7XHw8YMIB//vOfvu99/vt7MoC7776bmpqar13PoEGD\n2LNnj+87qT/+8Y889thjzd7/Zh5++GFqa2vJycnhySef9IXusGHD2LBhA9XV1QB4PB6OHTsGQGJi\nIuvWrfNdXHL9Dsl3332377u665/36NGjvlulbN68mbi4OIWWWIqOvMSSbDYbU6dOJSQkxHcksW3b\nNt58802WLl3K2rVr8Xq9REZG8vrrr/tec+M+oPGmfC+//DIzZ87knnvuYejQoYSEhNC5c2cApk2b\nxpQpU+jcuTPr168PuMbY2FgyMzOZOnUqHTp0oFevXixatKjZ+38Vl8tFbm4uGzZs8G1LTU2lqqqK\nyZMnY7PZaGhoYNKkSfTp04eZM2eSk5ODy+WiY8eOfPe732XFihUMGjSINWvW4HK5eOSRR3jxxRd5\n9dVXmTdvHl6vl27duvmO7ESsQrdEEQGuXLlCly5dgMYjkU2bNt2R33KJSHDoyEsEWL9+PUVFRXi9\nXrp27crixYtbuyQRuQUdeYmIiOXogg0REbEchZeIiFiOwktERCxH4SUiIpaj8BIREctReImIiOX8\nH4gzFtcS9o9MAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b20dd690\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(graph_means)\n", + "plt.ylabel('Time (seconds)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (vectorized TF operation)')\n", + "_ = plt.ylim(ymin=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4KZg2WXjbhg5" + }, + "source": [ + "## AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UQJBQWbCbinm" + }, + "outputs": [], + "source": [ + "# Sum written using for loop and converted with AutoGraph\n", + "def sum_all(elements):\n", + " sum_ = 0.0\n", + " length = len(elements)\n", + " for i in tf.range(length): \n", + " sum_ += elements[i][0]\n", + " return sum_\n", + "\n", + "def run_trial(num):\n", + " elements = get_elements(num)\n", + " return sum_all(elements)\n", + " \n", + "ag_means = []\n", + "ag_run_trial = ag.to_graph(run_trial)\n", + "\n", + "for num in range(max_elements):\n", + " with tf.Graph().as_default():\n", + " durations = []\n", + " foo = ag_run_trial(num)\n", + " with tf.Session() as sess:\n", + " for _ in range(burn_ins):\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " for _ in range(trials):\n", + " start = time.time()\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " ag_means.append(np.mean(durations))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 310, + "status": "ok", + "timestamp": 1532448438694, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "DLDOmrRW99v5", + "outputId": "ae0e0573-39db-4004-a064-efc618dbf867" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XdYVFf++PH3DE1AinTEjgULioggNiTYjcZCoiZqjEmM\n6RuzcVc32exPE7O72dTNmmhi1u+abLJqLFFsib1E7GLDjkgbQUSKyDAz5/eHySARcYwMQ/m8nsfn\nce4999zPPXOHzz23nKtRSimEEEKIu9DaOgAhhBC1gyQMIYQQFpGEIYQQwiKSMIQQQlhEEoYQQgiL\nSMIQQghhkTqZMObPn88bb7xh6zBqlQceeICffvrJ6uuZOXMmH330kdXXU1NcuHCBUaNG0a1bN776\n6itbh1OnnT17ljFjxtg6jGpV2e9Jr9czZMgQcnNzq2x9tTJhdO3alfDwcMLDw2nfvj1dunQxT1uz\nZg3PPPMMc+bMsXoc6enphISEYDKZrL6uqlTf/mjfjTW/xy+++IKoqCgOHDjAhAkTbpu/bt06xo0b\nR1hYGJMmTbpt/smTJxk9ejRhYWGMGTOG5OTkcvPfffddoqKi6NGjB+++++49LWttEydOZNmyZdW2\nvo8//pinnnqqwjgiIyMpLS29p/pCQkK4dOnSPS3z1VdfMWLECMLCwujduzeTJk1i7dq191RHVXF0\ndCQ+Pp7PP/+8yuqslQnj0KFDHDx4kIMHD9K4cWPmz59vnvbggw9WWxxKKTQaDfLsY+1mze8xIyOD\n1q1b33G+p6cnkydPZurUqbfNKy0t5fnnn2fkyJHs27ePkSNH8txzz2EwGAD49ttv2bx5M6tXr+b7\n779n69at/O9//7No2drgXr6P7OxsEhMTiYuLKzc9PT2dAwcOoNFo2Lx58z2tX6PR3FP5OXPmsHjx\nYmbOnMnevXvZsWMHv/vd79ixY8cdl7H2344HH3yQFStW3HOyvJNamTBupZS6rdE/+eQTXnvtNaDs\n6HH58uX069ePqKgovv32W44ePcqIESOIjIy8rTeybNkyhg4dSlRUFE899RQZGRkVrnvixIkARERE\nEB4ezpEjR1BKMW/ePB544AF69erFH//4RwoLCytc/urVq0ybNo3u3bsTFRVV7gj010c3t/YK9u7d\nS0xMDF988QU9e/akT58+/Pjjj2zbto1BgwYRFRXF/PnzK1znkiVLWL16NV988QXh4eE8++yz5nkn\nT55kxIgRdO/enenTp6PX683ztmzZwsiRI+nevTvjx4/n1KlTFdYPcO7cOaZMmUJUVBRDhgxh3bp1\ndyxbWb0PPPAACxcuZMSIEXTt2pXXX3+dK1eu8PTTTxMeHs6UKVMoKCgwlz98+DDjxo2je/fujBw5\nkr1795rnTZw4kY8++ojx48cTHh7Ok08+SV5ennkelP8eU1NTmThxIhEREURHRzN9+vQ7bsOmTZt4\n8MEHiYyMZNKkSZw/fx6Axx9/nMTERGbPnk14eDgXL168bdno6GgGDx6Mr6/vbfP27t2L0Whk0qRJ\nODg4MHHiRJRS7NmzB4CVK1cyZcoU/Pz88PPz44knnmDFihUAJCYmVrrsrdauXXvbqZxFixbx3HPP\nATdPbfztb38jNjaW3r1785e//KXcvvHjjz8ycuRIunXrxsCBA9m5cycffPABBw4cYM6cOYSHh/PW\nW28BcPDgQeLj4+nevTsPP/wwhw4dKvcdffDBB4wfP56wsDDS0tJYvnw5/fv3Jzw8nP79+7NmzZoK\nv4Ndu3bRsWNHHB0dy01fuXIlYWFhjB492tw2t67v1h7QihUrePTRRwGYMGECSilGjBhBeHi4eR9e\nsmQJAwcOJCoqiueee47Lly8DN089fvPNN3zwwQdER0fj6OiIRqMhPDycd955567bOHToUMLDwxkw\nYIA56f+yD8TExDB//nx69OhBXFwcq1evLrcd165d45lnniE8PJyxY8eW+7vh7++Ph4cHR44cqbDd\n7pmq5WJjY9Xu3bvLTfvnP/+pXnvtNaWUUmlpaapdu3bqzTffVCUlJWrXrl0qNDRUPf/88yo3N1dl\nZWWp6OhotW/fPqWUUj/88IMaOHCgOn/+vDIajerTTz9VY8eOrXDdaWlpKiQkRJlMJvO0pUuXqoED\nB6q0tDR1/fp19cILL5hj+bX33ntPvfnmm8poNCqDwaD2799vnhcSEqJSU1PNn//4xz+qDz/8UCml\nVGJiourQoYOaN2+eMhgMasmSJapHjx7q1VdfVdevX1dnzpxRoaGh6tKlSxWu99a6bm3Hhx9+WGVn\nZ6tr166pIUOGqG+//VYppdSxY8dUdHS0SkpKUiaTSa1YsULFxsYqvV5/W93Xr19XMTExasWKFcpk\nMqkTJ06oqKgodfbs2dvWfbd6Y2Nj1dixY9WVK1eUTqdT0dHRatSoUerkyZNKr9erSZMmqU8++UQp\npVRWVpaKjIxU27dvV0optXv3bhUZGalyc3OVUkpNmDBBDRgwQF28eFGVlJSoCRMmqPfee++O3+P0\n6dPVZ599ppRSqqSkRB04cKDCtjx//rwKCwtTu3fvVgaDQX3++edqwIABqrS01LzepUuXVrjsrZYs\nWaImTpxYbtq///1v9fTTT5eb9swzz6h///vfSimlunXrpo4cOWKed/ToURUeHm7RsrcqLi5W4eHh\n6uLFi+ZpY8aMUWvXrlVKKfXWW2+pZ599VuXn56uioiI1bdo09f777yullDpy5Ijq1q2b+Teo0+nU\n+fPnK9z2vLw81b17d/X9998ro9Go1qxZo7p3767y8vLM5WNjY9XZs2eV0WhUBQUFKjw8XKWkpCil\nlMrOzjbvR7/2t7/9Tc2ePfu26QMGDFDffPONOnbsmOrYsaO6cuWKed6v41u+fLl69NFHzZ/btWtX\n7je4e/duFRUVZd7/5syZox577DGllFLffPONeuCBByqM7Va/3sbS0lK1detW82913759qkuXLurE\niRNKqbLf+l//+lel1+vV3r17VVhYmLpw4YJS6ubvKTIyUh09elQZjUb16quvqunTp5db57Rp09Ti\nxYvvGpslan0PwxIajYbnn38eR0dHevbsibOzM8OGDaNRo0b4+/sTERHBiRMnAPjf//7H1KlTadmy\nJVqtlqlTp5KcnExmZuYd61e39HDWrFnD5MmTCQoKwtnZmenTp7N27doKz4/b29uTnZ1NWloadnZ2\ndOvWrcI6K+Lg4MC0adOws7Nj6NChXL16lccffxxnZ2dat25N69atK+0FVGTSpEn4+Pjg7u5ObGws\nJ0+eBGDp0qWMGzeO0NBQNBoNI0eOxNHRscKjli1bttCkSRNGjhyJRqOhffv2DBw4kPXr199W1pJ6\nJ0yYgJeXF35+fkRERNClSxdCQkJwcHBgwIAB5hi///57+vXrR58+fYCbR+6dOnVi27Zt5rpGjx5N\ns2bNcHR0ZMiQIeZlf3Frm9vb25Oeno5Op8PR0ZHw8PAK22zdunX069eP6Oho7OzsePLJJ7lx40a5\nI+ff6vr167i5uZWb1rBhQ3OP9dfz3dzcuH79ukXL3qpBgwbExcWZj95TUlK4cOGC+fTOsmXLmDlz\nJm5ubri4uDB16lRz2WXLlhEfH090dDQAfn5+tGzZssLt2bp1Ky1atGD48OFotVqGDRtGq1at2LJl\ni7nMqFGjCA4ORqvVYmdnh52dHadPn6akpAQfHx+Cg4MrrLugoABXV9dy0/bv309GRgZDhgyhY8eO\nNGvW7Laj83uxZs0a4uPjzfvf9OnTOXz4MBkZGVy9evW2XmJMTAzdu3enc+fO5f5+3LqN9vb2xMTE\n0KRJE+BmL7dXr17s37/fXF6j0fC73/0OBwcHunfvTkxMTLle+8CBA+nUqRNarZbhw4fftl+7urqS\nn5//m7f7VvZVUkst4O3tbf5/gwYN8PHxMX92cnIy/9AyMjJ4++23+dvf/gaUnd/W6XQEBgbedT2X\nL1+mcePG5s9BQUEYDAZycnLw8/MrV/app57in//8J1OmTEGj0fDwww9XeC67Ip6enuZzrA0aNKhw\nG3/ZJkvduryzszPZ2dnAzTZZtWqV+S4fpRQGg8HcHb9VRkYGhw8fJjIy0lzWaDQycuTICsverd5b\nY3Jycrrt863f27p168x/fH6p65c/ZEC579zZ2bnS9pkxYwYffvgh8fHx5usMFd2B8+vvW6PREBgY\niE6nu2PdlnJxcbntD3xhYSENGzascH5hYSEuLi4WLftrw4YN4+9//zvPPfcca9asoX///jg6OpKb\nm0txcXG5bTeZTObkmpWVRUxMjEXb8+u2AmjcuHG5tgoICDD/39nZmQ8++ICFCxcya9YsunXrxowZ\nM2jVqtVtdbu7u1NUVFRu2qpVq+jduzceHh7mbVy5ciWPP/64RfFWFH/Hjh3Nn11cXPD09ESn0+Hp\n6Xnb72Hbtm0YjUY6depU7mDk1m38pdy8efNISUnBZDJx48YN2rVrV27bnJyczJ8bN25cbl1326+L\niopwd3f/Tdv8a/UmYVgqICCAZ5991qKL5xVdFPPz8yt3zSM9PR17e/tyX+ovXFxc+MMf/sAf/vAH\nzp07x8SJE+ncuTM9evTA2dmZ4uJic9ns7OzbdrTqEhAQwLRp03jmmWfuWjYwMJCoqCgWLlxYpfVa\nst6RI0cye/bse162ou/R29vbfG3rwIEDPPHEE0RGRtK0adNy5fz8/Dhz5ky5aZmZmVXyXbVp04ZF\nixaVm3b69GnzNZfWrVuTnJxMaGgocPMaVJs2bSpdtqI7tQB69+7NzJkzSU5OJiEhgVmzZgHQqFEj\nnJ2dWbNmzW0HPHDzO7zTnUS/blc/Pz82btxYblpGRgZ9+/a94zK9evWiV69e6PV6PvjgA9544w2+\n/vrr29bVrl07Vq1aZf5cUlLCunXrMJlM9O7dG7h5I0B+fj6nTp2iXbt2uLi4cOPGDfMyvxwg3cmv\nf9vXr18nLy8Pf39/PD09eeuttzh+/Hi5pAK3ny24dRv1ej0vv/wy7777LnFxcWi1Wp5//vlyy+Tn\n53Pjxg3zgWFmZiZt27atNNZbnT9/nieffNLi8pWpF6ek7nZ651bjx49n/vz5nD17FrjZ1a3odAqA\nl5cXWq2W1NRU87Rhw4axaNEi0tLSKCoq4oMPPmDYsGFotbc39datW83Luri4mLvgcPOi95o1azCZ\nTGzfvp19+/ZZvA134+Pjc0+3Cz7yyCN8++23JCUlATd/KNu2bavwCL1fv35cuHCBVatWYTAYKC0t\n5ejRo+YLwb+13rsZMWIEmzdvZufOnZhMJkpKSti7d69FR/oVfY/r1683L+vu7o5Wq63wOxwyZAhb\nt25lz549GAwGFi5ciJOTE2FhYRbFbTKZ0Ov1GAyGcv8HiIyMRKvVsnjxYvR6vbknFhUVBcDIkSNZ\ntGgROp0OnU7HokWLGD16dKXL9ujRo8I47OzsGDRoEH//+9/Jz8+nV69eAOae79y5c8338+t0Onbu\n3AlAfHw8y5cvZ8+ePSil0Ol05u/61/tZTEwMFy9eJCEhAaPRyNq1azl//jyxsbEVxnTlyhU2b95M\ncXEx9vb25t9IRXr16sXx48fNF+N/+OEH7OzsWLduHatWrWLVqlWsXbuWbt26sXLlSuDmb2zjxo3c\nuHGDixcv8t1335Wr89fxP/jggyxfvpzk5GT0ej3vv/8+Xbp0oXHjxrRs2ZKxY8cyffp0du/eTUlJ\nCSaTiYMHD1Z6t1VpaSmlpaU0atQIrVbLtm3b2LVrV7kySik+/vhjSktL2b9/P1u3bmXIkCF3rPNW\nOp2Oa9eu0aVLF4vK302tTxiW3Pr26zKVfe7fvz9PP/00r7zyChEREYwYMeKOt8U1aNCAadOmMX78\neCIjI0lKSiI+Pp6HHnqICRMmMGDAAJydnXn99dcrXD4lJYXJkyfTtWtXxo8fz2OPPUb37t0B+NOf\n/sTmzZvp3r07CQkJ9O/f/7628Vbx8fGcPXuWyMhIXnjhhbuW79SpE3PmzGH27NlERkYyaNCg2+44\n+YWrqytffvkla9eupU+fPvTp04f33nuv3F01ltZ7L9sUEBDAvHnzmD9/PtHR0cTGxvLll1+aDxYq\nW7ai7/Ho0aM8/PDDhIeH8/zzz/OnP/2JoKCg25Zt2bIl7777LnPmzCE6OpqtW7fy2WefYW9vf9f1\nws3TJp07d2b27NkcOHCALl26mB86dXBwYN68eaxYsYLIyEiWL1/OvHnzzHWPGzeO2NhYRowYwYgR\nI4iNjeWRRx6xaNmKDBs2jJ9++okhQ4aUS46///3vad68OY888ggRERFMmTKFlJQUADp37szcuXOZ\nO3cu3bp1Y9KkSebz9ZMmTWL9+vVERUXx9ttv4+npyWeffcbChQvp0aMHCxcuZP78+eZTRr9uK5PJ\nxL///W/69u1Ljx492LdvH2+++WaFsXt7e9OjRw82bdoE3Lw7asyYMfj7++Pt7W3+99hjj7F69WpM\nJhOTJ0/GwcGBXr16MXPmTIYPH16uzhdffJEZM2YQGRnJ+vXriY6O5uWXX+bFF1+kT58+pKWl8f77\n75vL//nPf2bixIm88847REVFERMTw8cff8yHH35oPhX36210dXXlT3/6Ey+//DKRkZGsXbv2tluD\nfX198fDwoE+fPsyYMYPZs2fTokWLO36Pt1q9ejWjRo3CwcHBovJ3o1H3cvh9j2bNmsXWrVvx9va+\n48WmxMRE3nnnHQwGA40aNWLx4sXWCkcIUYedO3eOP/7xjyxdutTWoVSZvXv3MmPGDLZu3XrPy+r1\nekaOHMlXX32Fl5dXlcRj1YSxf/9+XF1dmTFjRoUJo6CggHHjxvHll1/i7+9Pbm5ulW2YEELUdveT\nMKzBqqekIiIiKr06v3r1agYOHIi/vz+AJAshhKjBbHoNIyUlhWvXrjFx4kTGjBljvhglhBDi5s0L\nNaV3ATa+rdZoNHLixAn+7//+j+vXrzNu3Di6du1K8+bNbRmWEEKICtg0Yfj7+9OoUSOcnJxwcnIi\nIiKC5OTkuyaMXx6mE0IIUX2snjAqu6YeFxfHW2+9hdFoRK/Xk5SUxBNPPHHXOjUaDdnZBXctVx/4\n+rpJW/xM2qKMtEUZaYsyvr5udy9UCasmjFdffZXExETy8vLo168fL774IqWlpWg0GsaOHUtwcDC9\ne/dmxIgRaLVaHnnkkUqHghZCCGE7Vr2t1prkiOEmOXoqI21RRtqijLRFmfvtYdT6J72FEEJUD0kY\nQgghLCIJQwghhEUkYQghhLCIJAwhhBAWkYQhhBB1VFXfBCsJQwgh6qD07EJe+HAHe05kVVmdkjCE\nEKIOWrnjAsUlBlwbVM3Lk0AShhBC1GqLN5zi201nyp1+StUVcOB0Nq0au9OpZdW9NsKmgw8KIYT4\n7c6mX2PLoXQA/Bs5ExveBIBVOy8AMLJ3yyodqFUShhBC1FIb9qYC4Oig5ZtNZ2gR6I5Wo+HQmRyC\ng9zpWIW9C5BTUkIIUStdvnqdg6eyaR7gxgujQjEaFZ+uPMaSLWcBGNmnVZW/BkJ6GEIIUcMopZi3\n8hhXC0ro0MKLTi29aNXYHXu7smP8jfsuoYDBkc3o1MqbB3u2YPXuFHKu3aBNEw86NG9U5XFJwhBC\niBrmXEY+B05lA3A+I581u1Nwc3HgyWHt6RzsQ2FxKTuTMvF2dyIixBeAh3q35Gz6NU5evGqV3gVI\nwhBCiBpn19FMAJ4b2Qk7Ow3HLuSy40gmHy5NYlh0c+y0GvQGEwO6N8NOe7PXodVqeGlMZzKuFNEy\n0N0qcUnCEEIIG7mYVcD5jGv06xpk7hHoS43sPamjkZsT4W190Wo1dG3jS9/Ojfl05TESfroIgLOT\nPX06B5arz8nRzmrJAuSitxBC2IS+1Mi/Vhxl8cbTJJ7UmacfOpNDcYmRnp0C0GrLTis1D3Djz5O7\nE9725imouG5BODtV7zG/9DCEEMIGNuxNJefaDQCWbD5Ll2AfnJ3szaejenYKuG0Zlwb2PD+qE5cu\nF9LEt2G1xgvSwxBCiGqXm3+DhD0XcXdxYGD3puQV6lm9K4WrBSUcT8klOMidQG/XCpfVaDQ083cr\n1/uoLlZNGLNmzaJnz54MHz680nJJSUl06NCBjRs3WjMcIYSoEZZtO4e+1MSYmGBG922Fj0cDfth/\nieXbzqEU9AoNvHslNmDVhDF69GgWLlxYaRmTycR7771Hnz59rBmKEELUCGfTrrHnuI7mAW706hyI\no4Md4/u3wWhS7DqWhb2dlsgQP1uHWSGrJoyIiAjc3Su/Yr948WIGDRqEl1fVPsIuhBA1jUkpvtl0\nGoBH+7dB+/OdUWGtfegc7A1AeFsfXKpwhNmqZNNrGDqdjh9//JHx48fbMgwhhKgWe45ncSGzgMj2\nfrRp4mmertFomDCgLV2CvRkW3cJ2Ad6FTe+Smjt3Lq+99pr5/uOqfjuUEELUFCWlRr7bdh4Hey0P\n92t923wfT2defriLDSKznE0TxrFjx3jllVdQSnH16lW2b9+Ovb09cXFxd13W19etGiKsHaQtykhb\nlJG2KFMT2uLbH05xtaCEh+PaENLa19bh/CZWTxiV9Ro2bdpk/v/MmTOJjY21KFkAZGcX3HdsdYGv\nr5u0xc+kLcpIW5SpCW1xtaCEpZtO4+7iQL/OgTaL534Tp1UTxquvvkpiYiJ5eXn069ePF198kdLS\nUjQaDWPHjrXmqoUQosZYseM8+lIT4+PaVPvT2VXJqpG/9957Fpd95513rBiJEEJUv5JSI9sOpbMr\nKZMmvq706dzY1iHdl9qb6oQQwoYMRhMnUq4S0swTRwe7cvOKSwxsOZTOhr2pFFwvxdFBy4SB7Wzy\ndHZVkoQhhBC/wfJt51m/N5UALxeeHNae4CAPlFLsP5XNf388zbVCPc5OdjzYszkDIpri5uJo65Dv\nmyQMIYS4R5lXivhh/yWcnezR5V5n7lcH6N+tKZm5RRw7n4u9nZbhPVswKLJpjX0I77eQhCGEEPfo\n201nMZoUU4aG4ObiyJcJJ/lh/yUAOrb0YsLAtvg3crFxlFVPEoYQQtyDI2dzOHr+Cu2bNyK8rS8a\njYb/NyWSH/ZfIsDLhW7tfK3yetSaQBKGEEJYyGA08e2mM2g1Gsb3b2NODE6OdjzYs4Vtg6sG8j4M\nIYSwgMFoYvm28+iuFhPbNcgmLzCyNelhCCFEJZRSHDqTw9Kt59DlXsfD1ZGH+rS0dVg2IQlDCCF+\n5WpBCeczrnE+I58TF69yMasArUZDv65BPNS7JQ2d686dT/dCEoYQQvys1GDi281n2HIw3TxNA3Rt\n40N8v+A7vja1vpCEIYSol86k5eHsaE9jX1e0Gg05ecXMW3mMlKwCGvu4Et3Rn1aB7rQIdK/V4z9V\nJWkFIUS9s+NIBv9elwxAQ2cH2jTx4PSlPIpuGOjVKYAJg9rh9KvhPoQkDCFEPXMq9Sr/2XAK1wb2\ndGntw6nUqxw6k4O9nZbJQ0Lo0zmwzj5Hcb8kYQgh6g3d1et8svwoAM+PCiWkeSOUUuRcu4GTgx3u\nrrV/vCdrkoQhhKgXim6U8tHSJIpuGJg8JISQ5o2Am+/T9vV0tnF0tYM8uCeEqPOMJhOfrTxGVu51\nBkc2o2+X2v1eCluRhCGEqPOWbjnH8ZSrdA72Jr5fsK3DqbUkYQgh6rQf915k475LBHq78MyIjrX+\nJUa2JAlDCFGrXS0oYfPBNIpulN4270xaHv9aloRrA3teiu8sz1PcJ6u23qxZs9i6dSve3t6sXr36\ntvmrV6/m888/R6PR4OLiwl/+8hfatWtnzZCEEHVIid7I+0sOk55dxIrt5xkW3YK4bkHorhazds9F\nEk/o0Gg0TBvZqU6+n6K6aZRSylqV79+/H1dXV2bMmFFhwjh8+DDBwcG4ubmxfft2PvnkE5YsWWJR\n3dnZBVUdbq3k6+smbfEzaYsy9aEtlFIsWH2CxBM62jdvxMWsAq6XGGjo7EBh8c3eRhNfV6aM6EQL\n3/o9pMcvfH3d7mt5q/YwIiIiSE9Pv+P8sLCwcv/X6XTWDEcIUYf8eCCNxBM6goPceeWRLpSUGkn4\n6SJbDqUTHOTOsOgWdAn2xs/Pvc4nz+pSY07oLV26lL59+9o6DCFELXD6Uh5LNp/F3cWB50aGYm+n\nxd5OyyOxrXkktrWtw6uzakTC2LNnD8uXL+e///2vxcvcb9eqLpG2KCNtUaautsW5tDzmrTyGAv44\nOZK2rXzuukxdbYvqZvOEkZyczJ///Ge++OILPDw8LF5Oupg31Ydz1ZaStihTV9vibPo1PlhyhBsl\nBh4fEkKAu9Ndt7OutsVvUaOvYcDNC1N3kpGRwUsvvcTf//53mjVrZu1QhBC1WPLFq3y0LIlSg4mn\nhncgumOArUOqd6yaMF599VUSExPJy8ujX79+vPjii5SWlqLRaBg7dizz5s3j2rVr/L//9/9QSmFv\nb8+yZcusGZIQopZJyy7kx/1p7D6WiVLw7MhOdGvna+uw6iWr3lZrTdLFvEm622WkLcrU9rYwmkwk\nnb3CjwfSOHnxKgA+Hg2YNLgdnVp631Ndtb0tqlKNPyUlhBCWulpQwq6jmWw9nE5ufgkA7Zs3on+3\nJnRp7SPDetiYJAwhhM2YTIp9yZc5kZLL6Ut56K4WA+DkYEds1yBiuwbRxK+hjaMUv5CEIYSwCX2p\nkc9Xn+DA6WwAGjjaEdrKm87B3vTsFCDjPtVA8o0IIapdwXU9//zuKGfTrxHSzJNHHmhNMz83OeVU\nw0nCEEJUq8t5xXyw5Ai63Ov06ODPE0Pb42AvA2fXBpIwhBDVJj2niH98e4hrhXqG9mjO6JhWaDXS\nq6gtJGEIIapFqq6Af3x7mMLiUsY90JqBkfKwbm0jCUMIYTUmk6KguJSLWQUs+P44xSUGHh/cjpiw\nIFuHJn4DSRhCiCp3IiWX/2w4RU7eDUw/Pxus1WhkSI9aThKGEKJKHTiVzfzvjwHQqrE7Hg0d8XR1\nIrydL+2bN7JxdOJ+SMIQQlRqx5EMDp3JoW9YYzoHe1d6kXrHkQwWrU/G0d6OF8eE0qGFVzVGKqxN\nEoYQ4o5+EQoNAAAgAElEQVRKDSaWbj1HYXEph8/mEOjtQv9uTWjo4oi+1Ii+1Ej+9VKuFtzgyrUb\nHE+5imsDe155JIxWjd1tHb6oYpIwhBB3dOhMNoXFpUS298PeTkviCR2LN56+Y3n/Rs68MDqUIF8Z\nzqMukoQhhLijHUmZAIzo1ZLGPq6M7tuKg6ez0Wg0ONprcXDQ4ubiiJebE54NnWQ4jzpOvl0hRIVy\n8oo5cSGX1kEeNPZxBcDLvQH9I5raODJhK/I8vhCiQjuPZqKAPl0CbR2KqCEkYQghbmMyKXYezaSB\nox3dQ/xsHY6oISRhCCFuczwll9z8EiLb+9PAUc5ci5skYQghbrP9SAYAfbs0tnEkoiaxasKYNWsW\nPXv2ZPjw4Xcs89ZbbzFw4EAeeughTp48ac1whBB3kXmliM9WHePgqWyCfF1pGXh/74AWdYtV+5qj\nR49m4sSJzJgxo8L527ZtIzU1lY0bN3LkyBHefPNNlixZYs2QhKi3SvRGPlt1DEcHO8Ja+xAa7I1r\nA3uuXLtBSlYBh87ksOdEFkpBc383Jg8JQSNDj4tbWDVhREREkJ6efsf5mzZtYuTIkQB06dKFgoIC\ncnJy8PHxsWZYQtRLq3encOTcFQD2JV9Gq9Hg7GRH0Q2DuUwTX1ce6t2K8LY+kizEbWx6Nevy5csE\nBJSNXOnv749Op5OEIUQVy7xSxIa9qXi7N+D50Z04fiGXw2dzKLheSocWXrQIcKNloDttm3nKC43E\nHdk0Yaifhz2+laVHNb6+cm71F9IWZaQtyvzSFkopPvouCaNJMW1MZ7p3CqR7aP16H4XsF1XDpgnD\n39+frKws8+esrCz8/Cy75zs7u8BaYdUqvr5u0hY/k7Yoc2tb7D2p48iZHDoHe9PKz7XetZHsF2Xu\nN3Fa/bbainoRv4iLi2PlypUAHD58GHd3dzkdJUQVKiwu5dtNZ7C30/Jo/zZyXULcF6v2MF599VUS\nExPJy8ujX79+vPjii5SWlqLRaBg7diwxMTFs27aNAQMG4OzszDvvvGPNcISos0r0Rq6XGCg1mjAY\nTBy9mMeWfakcu3AFg1ExolcL/Bq52DpMUctpVGVdgBpMupg3SXe7TH1rC6UUpy/lseVQOgdOZWM0\n3f5TbuLrSlQHfwZFNsPern4+p1vf9ovK3O8pKXnmX4ha6PSlPBZvOEV6ThEAgd4uNPN3w95Og4Od\nlqaBHoQ0cSfQ29XGkYq6RBKGELXMoTPZfLryOCaTIrK9H7Fdg2jb1LPc9Qk5qhbWIAlDiFpk19FM\n/r02GXt7DS+N6UynVt62DknUI5IwhKihDEYTmw+kkX3tBiaT4nqJgcQTOlwb2PO7h7sQHORh6xBF\nPSMJQ4ga6rtt59iw91K5aY3cnHjlkS40kXdmCxuQhCFEDXToTDYb9l7C38uFaSM6Ym+vxU6rwdvd\nCQd7O1uHJ+qpuyaMS5cusWzZMhITE8nKysLJyYmQkBAGDRrEwIEDsbeXnCNEVcrJK2bhmpM42Gt5\nbmQnmvpJb0LUDJX+tf/zn//M8ePHGTx4ML///e/x8fGhpKSEc+fOsXPnThYsWMBf/vIXwsLCqite\nIeo0g9HEp6uOc73EwOQhIZIsRI1SacKIi4tj9uzZt01v164dQ4cOJS8vj0uXLlWwpBDiXhmMJhYm\nnORCZj7RHQPo0znQ1iEJUU6lCSMmJqbShT09PfH09KzSgISoj0oNRj5deZzDZ3MIDnJn4qC2Mu6T\nqHEsGivgr3/9KwUFBRgMBh599FHCwsJYtWqVtWMTol64oTfw4dIkDp/NoUOLRvx+bFcaOMq1QVHz\nWJQwdu/ejZubGzt37sTf358NGzbw5ZdfWjs2Ieq8VF0Bf/3qICcvXqVrGx9eju+Mk6PcBSVqpns6\njNm3bx8DBgzA399fustC3IdSg5Hvd6Wwbk8qJqWICWvMhIFtsdPWzwECRe1gUcLw9vbm9ddfZ9eu\nXUydOhWDwYDRaLR2bELUOfpSI3tO6Fi35yK6q8V4uzfg8cHtZIgPUStYlDDee+89vv/+e+Lj4/Hw\n8CAtLY0nnnjC2rEJUWcUFpeycV8qWw9lUFhcilajoX9EE0b3bSXXK0StYdGe6uXlxeTJk82fmzRp\nQpMmTawVkxA12rXCErRaDW4ujhaV1129zgdLjnD5ajGuDewZFt2c2K5BeLk3sHKkQlStShPGc889\nx7Rp0+jcufNt8woLC/nuu+9o0KABY8eOtVqAQtQkJXojbyzcyw29gZ6dAhkU2bTSd06cS7/GR8uS\nKCwuZUiPZozo1RInB7moLWqnShPGSy+9xHvvvUdKSgqdO3fG29ubkpISzp8/T3p6OuPGjWP8+PHV\nFasQNrc3WUdhcSmODlq2H8lgx5EMQoO96djSi3ZNPWni2xC9wYgut5jzGdf43+azlBpNTBrcjn5h\nQbYOX4j7UmnCCAkJ4fPPPyczM5O9e/ei0+lwcnJi8ODBdOvWDUdHy7rkQtQVO45kogFmPxlFalYB\n6xIvknTuCknnrgDgaK9FbzCZyzs6aHlpTGe6tPaxUcRCVB2LrmEEBgby0EMP/aYVbN++nblz56KU\nYsyYMUydOrXc/MzMTP7whz9QUFCAyWRi+vTpd33CXAhbSM8p4mz6NTq29MLP0xk/T2e6tfMl+9oN\nzlzK49SlPFIyC3B3dcDfy4WARi6EBnsT4OVi69CFqBIWJYwrV67wzjvvkJmZyddff01ycjKHDh26\n6+kok8nEnDlzWLRoEX5+fsTHxxMXF0dwcLC5zKeffsrQoUMZN24c586d4+mnn2bz5s33t1VCWMGO\nIxkA9O3S2DxNo9GYk0evUBn7SdRtFj0l9Prrr9OtWzfy8/MBaNWqFf/973/vulxSUhLNmzcnKCgI\nBwcHhg0bxqZNm8qV0Wg0FBYWApCfn4+/v/+9boMQVldqMLH7WBYNnR3o2kZOL4n6yaKEodPpGD9+\nPHZ2N+/ucHR0RGvBE6k6nY7AwLKjLn9/fy5fvlyuzAsvvMCqVauIiYlh2rRpvPHGG/cSvxDV4vDZ\nHAqLS+kVGoC9nTyNLeoni05J/folSfn5+Sil7rqcJWUSEhIYM2YMkydP5vDhw7z22mskJCTcdTlf\nX7e7lqkvpC3KWKst9iw/CsBD/drUmvauLXFWB2mLqmFRwhg4cCB//vOfKSoqYvny5fz3v/9lzJgx\nd10uICCAjIwM82edToefn1+5MsuWLWPhwoUAhIWFUVJSQm5uLl5eXpXWnZ1dYEnodZ6vr5u0xc+s\n1RapugIOn86mdRMPGmhrx74n+0UZaYsy95s4LepbP/XUU0RERNCxY0e2bdvGxIkTefzxx++6XGho\nKKmpqaSnp6PX60lISCAuLq5cmcaNG7N7924Azp07h16vv2uyEKI6KKXYcjCNtxcfQAEDIpraOiQh\nbEqjLDlvdB+2b9/O22+/jVKK+Ph4pk6dyscff0xoaCixsbGcO3eO119/nevXr6PVapkxYwbR0dF3\nrVeOGG6So6cyVdkW+UV6Fq1L5vDZHFwb2DN5SHu6tfOtkrqrg+wXZaQtytxvD8OihHHlyhW++uor\nUlNTMRgM5ukfffTRfa38fsgOcJP8GMpURVvkX9ezYW8qmw+kU1JqJKSZJ08P70gjN6cqirJ6yH5R\nRtqizP0mDIuuYTz33HN06NCB6Oho851SQtQl6dmF7EjKZNvhDEpKjXg0dCS+XzCxXYPQauXdL0KA\nhQmjuLiYN99809qxCFGtSkqNbD+cwe5jWVzU3TwC9WjoyJiYVsSENcbBXg6OhLiVRQmjS5cunDp1\ninbt2lk7HiGqRfLFqyxal8zlvGLstBq6BHvTMzSQsNbekiiEuAOLEsa4ceOYMGECAQEBODmVnctd\ntmyZ1QIToqqUGkwUXNejN5go0RvZejidbYcz0GhgUGRThkQ1x91VBtIU4m4sShivvfYa06ZNo0OH\nDnINQ9Qql/OK+etXB8gr1Jeb3sTXlSeGtqdloLuNIhOi9rEoYTg5OfHkk09aOxYhqpS+1Mi85UfJ\nK9TTtY0Pbi4OONrb4e/lQkxYYxniQ4h7ZFHC6NOnD9u3b6dv377WjkeIKqGUYvHGU6ReLqRvl8ZM\nHhJi65CEqPUsShhLlixhwYIFuLq64ujoiFIKjUbDTz/9ZO34hPhNth/JYNfRLJoHuPHYgDa2DkeI\nOsGihPHdd99ZOw4hqkROXjE/ndCxetcFXBvY8/yoTnLXkxBVxKKEERQk7yIWNVvSuRx+WHKE4+d/\nflWqg5ZnHuqIj4ezjSMTou6oNGG89tprvPvuu4wZMwaN5vanXeW2WmFrJqX4fucFvt+VAkBIM0+i\nOwbQrZ0fLg0sOh4SQlio0l/ULy87+sMf/lAtwQhxL4pLDCxMOMnB09n4eDTgjSd74OYodz4JYS2V\nJoxfXskaGRlZLcEIYYkSvZHEkzrWJ6aSlXudkGaePDuyE62CPGSQOSGsSPrsotbIKywh4aeL7D6W\nSXGJEY0G+ndrwiMPtJZnKoSoBpUmjNOnT1f4bgq5rVZUtwOnsvm/9ckUFpfi2dCRARFN6dulMV7u\nDWwdmhD1RqUJo0WLFixYsKC6YhHiNsUlBr758Qw7j2biYK/lsQFt5SltIWyk0oTh6Ogot9QKmzAY\nTWw7nMGa3SlcK9LT3N+NqSM6EOjtauvQhKi3Kk0YDg4O1RWHEMDN22R/OpbFqp0XyLl2AycHOx7q\n3ZJh0c2lVyGEjVWaMJYsWVJdcQhBenYh/7fhFGfTrmFvp2Vg96YM7SFDjwtRU1j9Lqnt27czd+5c\nlFKMGTOGqVOn3lZm7dq1/Otf/0Kr1dKuXTv+8Y9/WDssUYPkF+n5Yf8l1iemYjQpurXzZdwDbfD2\nkAvaQtQkVk0YJpOJOXPmsGjRIvz8/IiPjycuLo7g4GBzmYsXL/LFF1/wv//9j4YNG5Kbm2vNkISN\nGIwmVu9KQW8w4uxkj4uTPVfyb3Ai5SqXLhcC4O3uxGMD2hHWxsfG0QohKmLVhJGUlETz5s3NF86H\nDRvGpk2byiWMJUuW8Oijj9KwYUMAvLy8rBmSsJENe1NZvTvltun2dlo6tGhEp5be9OvamAaO8miQ\nEDWVVX+dOp2OwMBA82d/f3+OHj1arkxKSgoA48ePRynF888/T58+fawZlqhmV67dYPWuFNxdHHhh\ndGdKDEau3zDg2sCe1kEeODrIaLJC1AZWTRhKqbuWMRqNpKam8vXXX5ORkcFjjz1GQkKCucdxJ76+\nblUVZq1X09vi84ST6A0mnn+4C9Fdm1h1XTW9LaqTtEUZaYuqYdWEERAQQEZGhvmzTqfDz8+vXBl/\nf3+6du2KVqulSZMmtGzZkpSUFDp16lRp3TJm0E2+vm41ui2Onr/CT0czadPEg07NPK0aa01vi+ok\nbVFG2qLM/SZOq97YHhoaSmpqKunp6ej1ehISEoiLiytXpn///uzZsweA3NxcLl68SNOmTa0Zlqgm\npQYTX/9wGq1Gw4SB7SocIl8IUXtYtYdhZ2fHG2+8wZQpU1BKER8fT3BwMB9//DGhoaHExsbSp08f\ndu3axbBhw7Czs2PGjBl4eHhYMyxhRUopUnWFHDydzcHT2Vy+Wkz/iCY09av8FKMQoubTKEsuNNRA\n0sW8qSZ1t3OuFfOvFce4mHUzHns7LWGtvXliaHucnax/91NNagtbk7YoI21R5n5PSck9jKJKnEnL\n41/Lj5J/vZSubXyI7hhAp1ZecpusEHWI/JrFfTEpxc6kTBZvOIVS8NiAtjwQHiTXK4SogyRhiHtm\nMilOpl7l4Kmb1ymuFelxcbLn2VGd6NhCHrwUoq6ShCHuSeaVIr5Yc5ILmTdf39vQ2YHenQMZFt0c\n/0YuNo5OCGFNkjCERUxKselAGsu2nqPUYKJ7iB+xXYNo09QDO60MOy5EfSAJQ9xVzrVivkw4SXJq\nHg2dHXj6wQ5EhPjdfUEhRJ0iCUPckVKKHUmZfLvpDDf0RsJa+/D44HZ4NHSydWhCCBuQhFHPGYwm\ntBoNWm35u5oycopYsuUsSeeu4Oxkx5PD2tOzU4Dc/SREPSYJox7LL9Lz1n/2U1JqJCLEj6j2/rg2\nsGf17hT2nbyMAjq2aMQTQ9vj5S4vMxKivpOEUU+ZlOLzNSduvjfb0Y4tB9PZcjDdPL+ZX0OG92pJ\neFsf6VUIIQBJGPXW2p8ucvxCLqGtvHlxTCinUvNIPKEjr6iE2K5BhLWWRCGEKE8SRj10+lIeK3ac\np5GbE0892B57Oy0dW3rRsaU8dCeEuDNJGPWIUoqz6df4bNUxNGh4ZkRH3FwcbR2WEKKWkIRRD5Qa\njOw/lc0P+y6R8vNIsg/3C6ZtU08bRyaEqE0kYdRRpy/lcfB0NmfTr3ExqwCjSaEBurbxYWD3prRr\n1sjWIQohahlJGHXQ0fNX+HDJERRgp9XQzL8hIc0aEdM1CD9PZ1uHJ4SopSRh1DGX84pZ8P1x7Oy0\nPPtQRzq09MLJwc7WYQkh6gBJGHVIid7IJ98dpeiGgSeGhNC1ra+tQxJC1CEyzGgdoZRi0fpk0rIL\n6dc1iD5dGts6JCFEHWP1hLF9+3YGDx7MoEGDWLBgwR3LrV+/npCQEI4fP27tkOqU4hIDWw+n8+aX\n+0g8oSO4sTuP9m9j67CEEHWQVU9JmUwm5syZw6JFi/Dz8yM+Pp64uDiCg4PLlSsqKuKrr74iLCzM\nmuHUKSaTImHPRTbsTeX6DQN2Wg0RIX481r8N9nbScRRCVD2rJoykpCSaN29OUFAQAMOGDWPTpk23\nJYyPPvqIp59+mi+++MKa4dQZeYUlLPj+OMmpeXi6OdG/WxNiwoJo5CbDjgshrMeqh6I6nY7AwEDz\nZ39/fy5fvlyuzMmTJ8nKyiImJsaaodQZxy5c4S9f7iU5NY+w1j7Mm/EAI/u0kmQhhLA6q/YwlFJ3\nnT937lz+9re/WbzML3x93e4rttomJTOfxWtPsvdEFvZ2Gp5+qBPD+7RCo9HI8B63qG/7RWWkLcpI\nW1QNqyaMgIAAMjIyzJ91Oh1+fmWv9iwqKuLs2bNMnDgRpRQ5OTk899xzfPrpp3Ts2LHSurOzC6wW\nd02Sqitgw95U9hzXoYC2TTwY178NLQLcyckpxNfXrd60xd1IW5SRtigjbVHmfhOnVRNGaGgoqamp\npKen4+vrS0JCAu+//755fsOGDfnpp5/MnydOnMjMmTPp0KGDNcOq8fSlRvaevMzWw+mcz8gHbr6f\nYnRMMKGtvGTYcSGETVg1YdjZ2fHGG28wZcoUlFLEx8cTHBzMxx9/TGhoKLGxseXKazQai09J1UUm\nk2LXsUxWbD9PXqEeDdA52JuYsMZ0ae2DVhKFEMKGNKqW/oWua13Mkym5fLv5LJcuF+JoryUuogmx\nXYPw8ah87CfpbpeRtigjbVFG2qJMjT4lJe5OX2pkyZazbD6Yjgbo1SmAUX1byTu0hRA1jiQMG0q7\nXMj874+TnlNEkI8rTz7YnhYB7rYOSwghKiQJwwZMSrHpQBpLt5zDYDTxQHgQj8S2xlFGlRVC1GCS\nMKpZbv4Nvlx7khMpV2no7MATQzvStY2MKiuEqPkkYVQTfamR3ceyWLb1HNdLDHQO9uaJISF4NJQn\ntIUQtYMkDCvLyStm86F0dhzJoOiGAUcHLZMGtyOmS2N5nkIIUatIwrASk0mxLvEiK3dcwGhSuLk4\n8GDP5vQLC5I7oIQQtZIkDCu4cu0GX6w5walLeXg0dCQ+JpjI9n442MtFbSFE7SUJowoppdhzQsdX\nG09TXGIgvK0vjw9uJ4MDCiHqBEkYVSS/SM9/Npzi4OlsnBzsmDwkhD6dA+U6hRCizpCEcZ+UUuxL\nvsxXG09TWFxK26aeTBnWHj/Pyof0EEKI2kYSxn1Iu1zIf388TXJqHg72WsbFtaF/RBMZJFAIUSdJ\nwvgNCotLWbnjPFsOpaMUdAn2ZlxcG/y9XGwdmhBCWI0kjHtgMim2HU5n+fbzFN0wEODlwvj+bQht\n5W3r0IQQwuokYVjoVOpV/vvjGS5dLqSBox2PxLamf0QT7O2s+lp0IYSoMSRh3EVu/g2WbDnL3pOX\nAegdGsiYmFYypIcQot6RhHEHJXojG/elkrDnIvpSEy0D3Xh0QFuCG3vYOjQhhLAJSRi/YjCa2JmU\nyaqdF7hWpMfNxYHH+relV+dAuftJCFGvWT1hbN++nblz56KUYsyYMUydOrXc/EWLFrF06VLs7e3x\n8vJi7ty5BAYGWjus2xiMJhJP6Fjz00V0uddxdNDyYM8WDI5shksDyatCCGHVv4Qmk4k5c+awaNEi\n/Pz8iI+PJy4ujuDgYHOZDh06sHz5cpycnPjmm2/4+9//zgcffGDNsMopNZjYmZTBusRUcq7dwE6r\noV/XIEb0aoGnXKcQQggzqyaMpKQkmjdvTlBQEADDhg1j06ZN5RJGZGSk+f9hYWGsXr3amiGVk3ml\niM9WHefS5UIc7LXEhTdhcFQzvD1kNFkhhPg1qyYMnU5X7vSSv78/R48evWP5ZcuW0bdvX2uGZLbr\naCZfbTxNSamRPp0DGd1X7nwSQojKWDVhKKUsLrtq1SqOHz/O4sWLqzyOkxev8p/1yZQaTTg72qPR\naEjLLsTZyY5pD3Uksr1/la9TCCHqGqsmjICAADIyMsyfdTodfn5+t5XbvXs3CxYs4KuvvsLBwcGi\nun193Swqd+R0Nh8tS8JkMuHl4Uz+dT3Xbxjo2Mqb343rSoC3q2UbU4NZ2hb1gbRFGWmLMtIWVcOq\nCSM0NJTU1FTS09Px9fUlISGB999/v1yZEydO8Oabb7Jw4UIaNWpkcd3Z2QV3LXPswhX++d1RlFK8\nMDqUzsE+wM2ej0ajAZPJonpqMl9ft1q/DVVF2qKMtEUZaYsy95s4rZow7OzseOONN5gyZQpKKeLj\n4wkODubjjz8mNDSU2NhY3n33XYqLi3n55ZdRStG4cWPmzZt33+s+fCaHeSuPAfDimM7lxnuSd1QI\nIcS906h7udBQg2RnF1BSamRnUiYdWjQi8OdTS0op1iemsmzrOezttbw4JpROLevu4IBy9FRG2qKM\ntEUZaYsyNbqHYW3fbTvHj/vT0AARIX4MimzGjwcusee4jkZuTrwwOpSWge62DlMIIeqEWpswUnUF\nbDqQho9HA1ydHdiXfJl9yTcHCAxu7M7zo0PlwTshhKhCtTJhmEyKxRtPoRRMGtyOji28OHYhlw17\nU/Fv5MK4uDY42Muw40IIUZVqZcL4cV8q59LziQjxM1+fCG3lLS8yEkIIK6qVh+GL1pzAydGO8XFt\nbB2KEELUG7UyYRRc1zOyd0sauck1CiGEqC61MmE82Lslcd2a2DoMIYSoV2plwnhmVGd5l7YQQlQz\n+asrhBDCIpIwhBBCWEQShhBCCItIwhBCCGERSRhCCCEsIglDCCGERSRhCCGEsIgkDCGEEBaRhCGE\nEMIikjCEEEJYRBKGEEIIi1g9YWzfvp3BgwczaNAgFixYcNt8vV7PK6+8wsCBAxk7diwZGRnWDkkI\nIcRvYNWEYTKZmDNnDgsXLmTNmjUkJCRw7ty5cmWWLVuGh4cHGzdu5PHHH+fdd9+1ZkhCCCF+I6sm\njKSkJJo3b05QUBAODg4MGzaMTZs2lSuzadMmRo0aBcCgQYP46aefrBmSEEKI38iqCUOn0xEYGGj+\n7O/vz+XLl8uVuXz5MgEBAQDY2dnh7u5OXl6eNcMSQgjxG1g1YSil7rmMUgqNRmOtkIQQQvxG9tas\nPCAgoNxFbJ1Oh5+f321lsrKy8Pf3x2g0UlhYiIeHx13r9vV1q/J4aytpizLSFmWkLcpIW1QNq/Yw\nQkNDSU1NJT09Hb1eT0JCAnFxceXKxMbGsmLFCgDWr19Pjx49rBmSEEKI30ijLDlvdB+2b9/O22+/\njVKK+Ph4pk6dyscff0xoaCixsbHo9Xpee+01Tp48iaenJ++//z5Nmsj7uoUQoqaxesIQQghRN8iT\n3kIIISwiCUMIIYRFJGEIIYSwSK1LGHcbm6ouy8rKYtKkSQwdOpThw4fzn//8B4Br164xZcoUBg0a\nxJNPPklBQYGNI60eJpOJUaNGMW3aNADS0tJ45JFHGDRoENOnT8dgMNg4wupTUFDASy+9xJAhQxg2\nbBhHjhypl/vFokWLePDBBxk+fDivvvoqer2+Xu0Xs2bNomfPngwfPtw8rbL94K233mLgwIE89NBD\nnDx58q7116qEYcnYVHWZnZ0dM2fOZO3atXz77bd8/fXXnDt3jgULFhAdHc2GDRuIiopi/vz5tg61\nWvznP/8hODjY/Pkf//gHTzzxBBs2bMDNzY1ly5bZMLrq9fbbbxMTE8O6detYtWoVrVq1qnf7hU6n\nY/HixSxfvpzVq1djNBpJSEioV/vF6NGjWbhwYblpd9oPtm3bRmpqKhs3bmT27Nm8+eabd62/ViUM\nS8amqst8fX1p3749AK6urgQHB6PT6cqNxzVq1Ch+/PFHW4ZZLbKysti2bRsPP/ywedqePXsYNGgQ\ncLMdfvjhB1uFV60KCwvZv38/Y8aMAcDe3h43N7d6uV+YTCaKi4sxGAzcuHEDPz8/EhMT681+ERER\ngbu7e7lpv94PfvmbuWnTJkaOHAlAly5dKCgoICcnp9L6a1XCsGRsqvoiLS2N5ORkunTpwpUrV/Dx\n8QFuJpWrV6/aODrrmzt3LjNmzDAPI3P16lU8PDzQam/u0gEBAfVm30hLS6NRo0bMnDmTUaNG8cYb\nb1BcXFzv9gt/f3+eeOIJ+vXrR9++fXFzc6NDhw64u7vXy/3iF7m5ueX2g9zcXKD8OH5ws/10Ol2l\nddWqhCGPjNxUVFTESy+9xKxZs3B1da13Y29t3boVHx8f2rdvb94nlFK37R/1pV0MBgMnTpzg0Ucf\nZcWKFTg7O7NgwYJ6s/2/yM/PZ9OmTWzZsoUdO3ZQXFzM9u3bbytX39rlTir6e3q3trHqWFJVzZKx\nqcpAbb4AAAdwSURBVOo6g8HASy+9xEMPPUT//v0B8Pb2JicnBx8fH7Kzs/Hy8rJxlNZ18OBBNm/e\nzLZt2ygpKaGoqIi5c+dSUFCAyWRCq9WSlZVVb/aNgIAAAgICCA0NBWDgwIF8/vnn9W6/2L17N02b\nNsXT0xOA/v37c+jQIfLz8+vlfvGLO+0H/v7+ZGVlmctZ0ja1qodhydhUdd2sWbNo3bo1jz/+uHna\nAw88wPLlywFYsWJFnW+T6dOns3XrVjZt2sT7779PVFQU//jHP4iKimL9+vVA/WiHX/j4+BAYGMiF\nCxeAm9dyWrduXe/2i8aNG3PkyBFKSkpQSrFnzx7atGlT7/aLX/cc7rQfxMXFsXLlSgAOHz6Mu7u7\n+dTVndS6oUEqGpuqvjhw4AATJkygbdu2aDQaNBoNr7zyCp07d+Z3v/sdmZmZNG7cmI8++ui2C191\n1d69e/nyyy/57LPPuHTpEtOnTyc/P5/27dvz7rvv4uDgYOsQq0VycjJ/+tOfMBgMNG3alHfeeQej\n0Vjv9otPPvmEhIQE7O3t6dChA2+99RZZWVn1Zr949dVXSUxMJC8vDx8fH1588UX69+/Pyy+/XOF+\nMHv2bHbs2IGzszPvvPMOHTt2rLT+WpcwhBBC2EatOiUlhBDCdiRhCCGEsIgkDCGEEBaRhCGEEMIi\nkjCEEEJYRBKGEEIIi0jCEDXaAw88wNmzZ6tlXZ988km5oa9nzpzJ119/fd/1zpw5k+HDhzN9+vT7\nrqsyycnJrFu3zqrrEPWbJAwhfvbJJ59QWlpapXXm5OTw/9u7v5AmuziA49/ln7S8KOvWoghaI8KL\nihkJWon0R/Y8S2NYOL1IEFqE3gjRRZZEBcPyJqE/lDSIyBp2UV4IEVgGXeyiDKMVFnSRltTmaPr4\ney/Eh3KL9vYG7+vb73O182znnN/DYL+dHfY7fX199Pb2EgwGf+vYcz1//vyXE8b09PRvjkb9H2nC\nUPPS69evOXjwIDU1NRiGYZc+AHA6nXR1dVFdXU1FRQV9fX32c/fv32fnzp14vV66urpwOp0kEgna\n2tpwOBz4fD5M0yQWiwEwPDyM3++nsrKS1tbWH8Zz584dqqqq8Hg8BAIBPn78SDwex+/38/XrV0zT\n5OrVq9/1CYfDHDp0yG5blkVpaaldL+3ixYvs27cPr9dLU1MTY2NjAExOTnL69GmqqqowDINAIMD4\n+DidnZ08fvwY0zRpb28HZiojmKaJx+OhoaGBt2/fAjP/kDcMg5MnT+Lz+Xj48OE/eTvUn0KU+g8r\nLy+Xly9ffndtampKTNOUaDQqIiKxWEwqKyvt9tq1a+X69esiIvL06VMpLS0VEZHR0VHZvHmzjIyM\niIjIlStXxOl0ysTEhN0vkUjY87S2tkptba0kk0lJJpOye/duGRgYSIlxeHhYtm7dKqOjoyIi0tHR\nIUeOHBERkXfv3onb7U57b4lEQtxut3z69ElERPr7+8Xv94uISDgclmPHjtmvDYVC0tLSIiIinZ2d\nEggEZGpqSkTE7t/T0yOHDx+2+4yNjYnb7ZZXr16JiMjNmzelpqZGREQGBwfF5XJJJBJJG5tS6egK\nQ807b968IRqN0tzcjGEY7N+/n8nJye9OX9y1axcAxcXFfPjwgWQySSQSYf369RQVFQFQXV2dMrbM\nqZSzY8cOcnJyyMnJweVyMTIyktJncHCQsrIyli1bBoDP52NgYOCn95GXl8f27du5e/cuMFMYbvYQ\npP7+fh49eoRhGBiGQSgU4v3798BMefe6ujqysrIA7Oqsc0UiEdatW8fq1asB2Lt3L0NDQ0xMTACw\ncuVKNmzY8NM4lZo1r8qbKwUzH+qFhYXcvn077fMOh4OFCxcC2AfnWJaVkgzmttPJzc21H2dlZaU9\nD1pEUs4RmJ33ZwzD4NSpU+zZs4cnT55w9uxZe8ympia8Xm/a+TKRLq5v24sWLcpoHKVm6QpDzTur\nVq0iLy+PcDhsX4tGo8TjcSD1A3W2XVxczLNnz+zf8b/d9wAoKCjgy5cvfzuekpISHjx4YO8x3Lhx\ngy1btqTMn87GjRuJxWIEg0EqKirsRLdt2zZCoRCfP38GIJlM8uLFCwDKy8u5du2avUE/e5JeQUGB\nvfcye79DQ0N22fOenh5cLpcmCvXLdIWh/tMcDgf19fVkZ2fb35h7e3u5cOEC7e3tXL58GcuyWL58\nOR0dHXafuWPAzEEyx48fp7GxkaVLl1JWVkZ2djb5+fkANDQ0UFdXR35+Pt3d3RnHuGbNGpqbm6mv\nr2fBggUUFRXR1taWMv+PGIbB+fPnCYVC9jWPx8P4+DgHDhzA4XAwPT1NbW0tTqeTxsZGgsEghmGQ\nm5vLihUrOHfuHCUlJVy6dAnDMNi0aRNHjx7lzJkztLS0YFkWhYWF9gpGqV+h5c3VHyUej7N48WJg\n5hv3rVu3fst/LZT6E+gKQ/1Ruru7uXfvHpZlsWTJEk6cOPFvh6TUvKErDKWUUhnRTW+llFIZ0YSh\nlFIqI5owlFJKZUQThlJKqYxowlBKKZURTRhKKaUy8hf8CwfjbzhfpQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b218dbd0\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(ag_means)\n", + "plt.ylabel('Time(s)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (AutoGraph)')\n", + "_ = plt.ylim(ymin=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "d7IAJ6Bwbk9t" + }, + "source": [ + "## Eager" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "XMu5-12yoOzY" + }, + "outputs": [], + "source": [ + "from tensorflow.python.eager import context" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "_vt9MzpyjQ4T" + }, + "outputs": [], + "source": [ + "# Sum written using for loop and run with tf.eager\n", + "def sum_all(elements):\n", + " sum_ = 0.0\n", + " length = elements.shape[0]\n", + " for i in tf.range(length): \n", + " sum_ += elements[i][0]\n", + " return sum_\n", + "\n", + "eager_means = []\n", + "for num in range(max_elements):\n", + " with context.eager_mode():\n", + " durations = []\n", + " for i in range(trials + burn_ins):\n", + " \n", + " start = time.time()\n", + " for _ in range(batches):\n", + " run_trial(num)\n", + " \n", + " if i \u003c burn_ins:\n", + " continue\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " eager_means.append(np.mean(durations))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 422, + "status": "ok", + "timestamp": 1532460024499, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "5gHVdMlD-A-T", + "outputId": "3b581cb7-7ef9-489c-92f1-3e52c0c2dc8a" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XlclNX+wPHPsC+CILviCiruhAiiqaAimqnglpZLWpkt\ntmh50+rXvZZ2q9veq5ulZYtlaZp7mQuaorjklkoqiwgKsojsy8yc3x9cBxFUVGBYvu+/mGfmPM93\nzjzMd85zznOORimlEEIIISphYuwAhBBC1F2SJIQQQtyQJAkhhBA3JElCCCHEDUmSEEIIcUOSJIQQ\nQtyQJAlg8eLFvPrqq8YOo14ZOHAge/furfHjzJs3jw8//LDGj1NXxMfHExERQc+ePfnuu++MHU6D\ndvbsWcaMGWPsMCo1btw4YmNjjR0G0EiSxD333IOfnx9+fn506tSJHj16GLZt2LCBxx9/nNdff73G\n40hOTsbHxwe9Xl/jx6pOje2L+lZq8nNcsmQJgYGBHDp0iEmTJlV4fvPmzUyYMAFfX1+mTJlS4flT\np04xevRofH19GTNmDDExMeWef+eddwgMDKR379688847t1W2pk2ePJlVq1bV2vE++ugjHn30UcPj\ngQMH0qNHD/z8/AzfD2+88UatxXOtRx55pM78zzWKJHH48GH+/PNP/vzzT5o3b87ixYsN2+6///5a\ni0MphUajQe5frN9q8nO8cOEC3t7eN3zewcGBhx9+mBkzZlR4rqSkhKeeeorw8HAOHDhAeHg4Tz75\nJFqtFoAVK1awfft21q9fz7p164iMjOTHH3+sUtn64HY+j7S0NKKjoxk0aFC57YsXL+bPP/80fD+8\n8sor1R3mTel0OqA0YUVHR5Oenl6rx69Mo0gS11JKVTiZPvnkE1588UWg7Ffi6tWrCQ4OJjAwkBUr\nVnD8+HFGjhxJQEBAhVbHqlWruO+++wgMDOTRRx/lwoULlR578uTJAPj7++Pn58fRo0dRSvHpp58y\ncOBA+vbty0svvURubm6l5S9fvszMmTPp1asXgYGB5X5p+vj4cP78ecPja3/979+/nwEDBrBkyRL6\n9OlDv3792Lp1Kzt37iQsLIzAwEAWL15c6TF/+ukn1q9fz5IlS/Dz8+OJJ54wPHfq1ClGjhxJr169\nmD17NsXFxYbnduzYQXh4OL169WLixIn8/fffle4fIDY2lunTpxMYGMiwYcPYvHnzDV97s/0OHDiQ\npUuXMnLkSO655x5eeeUVMjIyeOyxx/Dz82P69Onk5OQYXn/kyBEmTJhAr169CA8PZ//+/YbnJk+e\nzIcffsjEiRPx8/PjkUceISsry/AclP8cExMTmTx5Mv7+/gQFBTF79uwbvodt27Zx//33ExAQwJQp\nU4iLiwNg6tSpREdHs2DBAvz8/Dh37lyFskFBQQwdOhQXF5cKz+3fvx+dTseUKVMwNzdn8uTJKKXY\nt28fAL/88gvTp0/H1dUVV1dXpk2bxpo1awCIjo6+adlrbdq0qcJlmmXLlvHkk08CUFxczFtvvUVI\nSAj33nsv//znP8udG1u3biU8PJyePXsyZMgQdu/ezfvvv8+hQ4d4/fXXy/2C//PPPxk7diy9evVi\n3LhxHD58uNxn9P777zNx4kR8fX1JSkpi9erVDB48GD8/PwYPHsyGDRsq/Qz27NlDly5dsLCwKLf9\nRonm/PnzTJ06lcDAQIKCgnjhhRfK/Z+eOHHCcJnw2Wef5fnnny/XErjVefvFF18Yzlu9Xo+FhQVd\nunRh9+7dlcZTq1QjExISoqKiospt+/jjj9WLL76olFIqKSlJdezYUb322muqqKhI7dmzR3Xr1k09\n9dRTKjMzU6WkpKigoCB14MABpZRSv//+uxoyZIiKi4tTOp1O/fe//1UPPPBApcdOSkpSPj4+Sq/X\nG7atXLlSDRkyRCUlJan8/Hz19NNPG2K53rvvvqtee+01pdPplFarVQcPHjQ85+PjoxITEw2PX3rp\nJfXBBx8opZSKjo5WnTt3Vp9++qnSarXqp59+Ur1791Zz5sxR+fn56syZM6pbt27q/PnzlR732n1d\nW4/jxo1TaWlp6sqVK2rYsGFqxYoVSiml/vrrLxUUFKSOHTum9Hq9WrNmjQoJCVHFxcUV9p2fn68G\nDBig1qxZo/R6vTp58qQKDAxUZ8+erXDsW+03JCREPfDAAyojI0OlpqaqoKAgFRERoU6dOqWKi4vV\nlClT1CeffKKUUiolJUUFBASoXbt2KaWUioqKUgEBASozM1MppdSkSZNUaGioOnfunCoqKlKTJk1S\n77777g0/x9mzZ6vPPvtMKaVUUVGROnToUKV1GRcXp3x9fVVUVJTSarXqiy++UKGhoaqkpMRw3JUr\nV1Za9lo//fSTmjx5crltX331lXrsscfKbXv88cfVV199pZRSqmfPnuro0aOG544fP678/PyqVPZa\nBQUFys/PT507d86wbcyYMWrTpk1KKaXeeOMN9cQTT6js7GyVl5enZs6cqd577z2llFJHjx5VPXv2\nNPwPpqamqri4uErfe1ZWlurVq5dat26d0ul0asOGDapXr14qKyvL8PqQkBB19uxZpdPpVE5OjvLz\n81MJCQlKKaXS0tIM59H13nrrLbVgwYJy2yr7brjq3LlzKioqSpWUlKjMzEw1adIktWjRIqWUUsXF\nxSokJER9++23SqvVqi1btqguXbrc1nkbHh6uUlJSVFFRkeGYr7/+uvr3v/9daTy1qdG1JKpCo9Hw\n1FNPYWFhQZ8+fbC2tmb48OE4Ojri5uaGv78/J0+eBODHH39kxowZtG3bFhMTE2bMmEFMTAwXL168\n4f7VNb9WNmzYwMMPP0yLFi2wtrZm9uzZbNq0qdLr3WZmZqSlpZGUlISpqSk9e/asdJ+VMTc3Z+bM\nmZiamnLfffdx+fJlpk6dirW1Nd7e3nh7e9/0135lpkyZgrOzM/b29oSEhHDq1CkAVq5cyYQJE+jW\nrRsajYbw8HAsLCw4evRohX3s2LEDT09PwsPD0Wg0dOrUiSFDhvDrr79WeG1V9jtp0iSaNWuGq6sr\n/v7+9OjRAx8fH8zNzQkNDTXEuG7dOoKDg+nXrx9Q+gu9a9eu7Ny507Cv0aNH06pVKywsLBg2bJih\n7FXX1rmZmRnJycmkpqZiYWGBn59fpXW2efNmgoODCQoKwtTUlEceeYTCwsJyv5DvVH5+PnZ2duW2\nNWnSxPCL9/rn7ezsyM/Pr1LZa1lZWTFo0CDDr/SEhATi4+MNl25WrVrFvHnzsLOzw8bGhhkzZhhe\nu2rVKsaOHUtQUBAArq6utG3bttL3ExkZSZs2bRgxYgQmJiYMHz6cdu3asWPHDsNrIiIi8PLywsTE\nBFNTU0xNTTl9+jRFRUU4Ozvj5eVV6b5zcnKwtbWtsP2pp54iICCAXr16ERAQwMqVKwFo1aoVQUFB\nmJmZ4ejoyNSpUzlw4ABQ2iLV6XRMmjQJU1NTQkND6d69u2GfVTlvp0yZgpubW7mWja2tLdnZ2ZXG\nX5vMjB1AXeXk5GT428rKCmdnZ8NjS0tLwz/XhQsXWLhwIW+99RZQdr06NTUVDw+PWx7n0qVLNG/e\n3PC4RYsWaLVa0tPTcXV1LffaRx99lI8//pjp06ej0WgYN25cpdemK+Pg4IBGozG8n8re49X3VFXX\nlre2tiYtLQ0orZO1a9caRucopdBqtVy6dKnCPi5cuMCRI0cICAgwvFan0xEeHl7pa2+132tjsrS0\nrPD42s9t8+bNhi+cq/u6+uUFlPvMra2tb1o/c+fO5YMPPmDs2LGGfoPKRs5c/3lrNBo8PDxITU29\n4b6rysbGpsKXem5uLk2aNKn0+dzcXGxsbKpU9nrDhw/n7bff5sknn2TDhg0MHjwYCwsLMjMzKSgo\nKPfe9Xq9IaGmpKQwYMCAKr2f6+sKoHnz5uXqyt3d3fC3tbU177//PkuXLmX+/Pn07NmTuXPn0q5d\nuwr7tre3Jy8vr8L2Tz/9lN69e1fYnpmZyRtvvMHBgwfJz89Hp9Ph4OAAlPZvuLm5lXv9tf/7VTlv\nr30fV+Xl5WFvb19he22TJHGX3N3deeKJJ6rUAX71S/parq6u5fowkpOTMTMzK/cFdZWNjQ3/+Mc/\n+Mc//kFsbCyTJ0+me/fu9O7dG2trawoKCgyvTUtLq/TEqw3u7u7MnDmTxx9//Jav9fDwIDAwkKVL\nl1brfqty3PDwcBYsWHDbZSv7HJ2cnAx9VYcOHWLatGkEBATQsmXLcq9zdXXlzJkz5bZdvHixWj6r\n9u3bs2zZsnLbTp8+behD8fb2JiYmhm7dugGlfUrt27e/adnKRlgB3HvvvcybN4+YmBg2btzI/Pnz\nAXB0dMTa2poNGzZU+JEDpZ/htX1n17q+Xl1dXdmyZUu5bRcuXKB///43LNO3b1/69u1LcXEx77//\nPq+++irLly+vcKyOHTuydu3aCttv1CJ/99130Wg0bNiwAXt7e7Zu3WroN3FxcamQ5C9evEirVq0M\n7/lOztu4uDhGjhx5W2VqglxuqsStLt1ca+LEiSxevJizZ88Cpc3Yyi6VADRr1gwTExMSExMN24YP\nH86yZctISkoiLy+P999/n+HDh2NiUvGjiYyMNJS1sbExNK+htON6w4YN6PV6du3aZWgKVwdnZ+cb\n/mNXZvz48axYsYJjx44BpZcydu7cWekv8eDgYOLj41m7di1arZaSkhKOHz9u6My90/3eysiRI9m+\nfTu7d+9Gr9dTVFTE/v37q/SLvrLP8ddffzWUtbe3x8TEpNLPcNiwYURGRrJv3z60Wi1Lly7F0tIS\nX1/fKsWt1+spLi5Gq9WW+xsgICAAExMTvv32W4qLiw2/XAMDAwEIDw9n2bJlpKamkpqayrJlyxg9\nevRNy1b2qxrA1NSUsLAw3n77bbKzs+nbty+AoYW7aNEiMjMzAUhNTTV0wI4dO5bVq1ezb98+lFKk\npqYaPuvrz7MBAwZw7tw5Nm7ciE6nY9OmTcTFxRESElJpTBkZGWzfvp2CggLMzMwM/yOV6du3LydO\nnCjXoX4zeXl52Nra0qRJE1JTU8v9qPH19cXU1JTly5ej0+nYunWr4RyFOztvi4uLOXHihKFejanR\nJYnKfgXe6jU3ezx48GAee+wxnn/+efz9/Rk5ciR//PFHpfu1srJi5syZTJw4kYCAAI4dO8bYsWMZ\nNWoUkyZNIjQ0FGtr6xsOu0tISODhhx/mnnvuYeLEiTz00EP06tULgJdffpnt27fTq1cvNm7cyODB\ng+/qPV5r7NixnD17loCAAJ5++ulbvr5r1668/vrrLFiwgICAAMLCwgyjaK5na2vLl19+yaZNm+jX\nrx/9+vXj3XffrfSf91b7vZ335O7uzqeffsrixYsJCgoiJCSEL7/80vAD4WZlK/scjx8/zrhx4/Dz\n8+Opp57i5ZdfpkWLFhXKtm3blnfeeYfXX3+doKAgIiMj+eyzzzAzM7vlcQHWrl1L9+7dWbBgAYcO\nHaJHjx6GG0HNzc359NNPWbNmDQEBAaxevZpPP/3UsO8JEyYQEhLCyJEjGTlyJCEhIYwfP75KZSsz\nfPhw9u7dy7Bhw8olxBdeeIHWrVszfvx4/P39mT59OgkJCQB0796dRYsWsWjRInr27MmUKVMM/XdT\npkzh119/JTAwkIULF+Lg4MBnn33G0qVL6d27N0uXLmXx4sU0bdq00rrS6/V89dVX9O/fn969e3Pg\nwAFee+21SmN3cnKid+/ebN26tdz2J554wnBPlZ+fH7NmzQLg6aef5q+//sLf35+ZM2cSFhZmKGNu\nbs7HH3/MypUr6dWrFxs2bGDgwIGG/oXbPW+hdARcYGBgpaPYaptG3c7P5ts0f/58IiMjcXJyYv36\n9QDExMTw2muvUVRUhJmZGa+99pqh+SuEELUlNjaWl156ydA5XZ3Gjx/PxIkTiYiIuKPyDzzwAAsX\nLrzpPTO1pUaTxMGDB7G1tWXu3LmGJPHII48wbdo07r33Xnbu3MmSJUv49ttvayoEIYSocQcOHKBt\n27Y4Ojqybt06/vWvf7F169ZK+xbrmxrtuPb39yc5ObncNo1GY7ihKScnp8KoACGEqG/i4+N57rnn\nyM/Pp1WrVnz00UcNIkFADbckoHS0zsyZMw0tidjYWB599FHDnc8rVqyo0lBRIYQQta/WO65/+OEH\nXn75ZSIjI5k3b55h6JwQQoi6p9aTxC+//GIYeTN06NByQ8VupoYbPEIIUe9l5xVz/Gx6tX5f1vjN\ndNcH6+bmxv79+wkICGDv3r20adOmSvvRaDSkpeXc+oWNgIuLndTF/0hdlJG6KNPY6iL9SgFb9p9n\n17ELFJfo+ee0XrRyK51mxcXF7halb65Gk8ScOXOIjo4mKyuL4OBgZs2axeuvv84bb7yBXq/H0tKy\nVtZxEEKIhijjSiG/7I5j71+p6JWimb0lw4Jb09K18ulU7kSNd1xXp8b0y+BmGtuvpJuRuigjdVGm\noddFXmEJG/eeY+vBJLQ6Pc2dbRkW2IrAzm6YmZbvRajTLQkhhBDV61RCJv9de4LcghKa2VsS0a8d\nQV3cMTG59WwSd0KShBBC1BORR5JZvuU0AGODvRjc0xML88rnp6oukiSEEKKO0+n1/LQ9lt8PnqeJ\ntTlPj+5Gh5YOtXJsSRJCCFFHXczIY/fxi+z9K4Ws3GI8nGx4dlwPXB2say0GSRJCCFHHFBRp+WL9\nSY6cTQfA2tKMgX4tGN3fCxur2v3aliQhhBB1SGZ2IR+sPEZSWi7enk0Z5OfJPe2da7zv4UYkSQgh\nRB1x/lIuH6w8yuWcIkL8WvDQ4A41NmqpqiRJCCGEESmlOJt8hT3HU4g+mUpRiY7xId6EBbSs0iJp\nNU2ShBBCGMmhvy+xMjKWS5dL16d3tLNk+vBO9PKpuD64sUiSEEKIWqbXK9b8EcfGvecwNzMhqIsb\nfbp50KmVo9EvL11PkoQQQtSi/MISFq87yfG4DFwdrZk1uhstXKpvrqXqJklCCCFqScaVQt798Qgp\nmfl0bdeMx0d2wdbK3Nhh3ZQkCSGEqAUX0vN498cjXM4pYkivlowP8a5zl5YqI0lCCCFqWNyFbD5Y\neZTcghLGhXgxLLC1sUOqMkkSQghRA5RSJKTksPvYRfb8dZESrZ5pw3zo16O5sUO7LZIkhBCimh2L\nzWBV5FmS0vIAcGhiweQRHbmng4uRI7t9NZok5s+fT2RkJE5OTqxfv96w/dtvv2X58uWYm5szYMAA\nXnjhhZoMQwghas3Wg+f5YdsZTDQaenZ0oV93D7q0bYapicmtC9dBNZokRo8ezeTJk5k7d65hW3R0\nNDt27GDDhg2YmZmRmZlZkyEIIUSt0OsVK7afYevBJOxtLXh2bHfaetgbO6y7VqNJwt/fn+Tk5HLb\nfvjhBx577DHMzEoP3axZs5oMQQghaoRerzgWm0Fyei5pWYWcS83hXEoOzZ1teW5sd5xrcTrvmlTr\nfRIJCQkcPHiQ999/H0tLS+bOnUu3bt1qOwwhhLhjWp2epRtPEX0ytdz27l5OzBjRGZs6fu/D7aj1\nJKHT6cjOzuann37i2LFjPPfcc2zbtq1KZe92Qe+GROqijNRFGamLMjVVF0UlOv799QEOnkqlU5tm\njB3UHrdmNrg52mBl2fDGAtX6O3J3d2fIkCEAdO/eHRMTEy5fvoyjo+Mty6al5dR0ePWCi4ud1MX/\nSF2UkbooU1N1UVCk5eOfjxGTmEWXts14OqIblhal6zzkZBdQF2v/bpNljXe3K6XKPR48eDB79+4F\nID4+Hq1WW6UEIYQQxhR/MZvXvz5ITGIWPTu68MyY7oYE0ZDVaEtizpw5REdHk5WVRXBwMLNmzWLM\nmDHMmzePESNGYG5uzltvvVWTIQghxF3R6fVs3HuO9XsS0OkVQ3q1ZFyIV70d0nq7NOr6n/p1mDSl\nS8llhTJSF2WkLspUV12kXylg8boTxCZn42hnySPDO9G5Tf0akXm3l5saXi+LEEJUgyNn01m64SR5\nhVoCOrkyOaxjnZ+xtSZIkhBCiGsUlehY+0c8v+5PxMzUhKlDO9K/R/M6sZSoMUiSEEIIILeghO2H\nkth6KIncghLcmtnwxKgutHJr3MOKJUkIIRo1vV6xYW8Cm/ado7hEj62VGff3acOwwFZYN8D7Hm6X\n1IAQotG6klfM5+tOcOrcZRyaWDCmf2v69fDAykK+Gq+SmhBCNEp/J17ms7UnuJJXjK+3M4/c36lR\ndkzfiiQJIUSjUqLV88vuOH6NTkSDhnEhXgwNaNVoO6ZvRZKEEKLRSEzNYcmGkySl5eHc1IrHRnSm\nvaeDscOq0yRJCCEavBKtno17E9i49xw6vWKAb3PGh3hLx3QVSA0JIRq00+ez+PrXGC5m5ONoZ8nU\noT5093Iydlj1hiQJIUSDlH6lgLV/xLPnrxQ0wCA/T0YPaCeth9sktSWEaFCu5Baxenc8m6NKJ+Tz\ndLFlylAfvFs0NXZo9ZIkCSFEvafT6zkRn8nu4ykcOZOGVqdwcbAi/N52BHZ2w8RERi7dKUkSQoh6\nq6hYx/Y/k9hy8DxXcosBaO5sS/gAL3zbNcPMtHFM512TJEkIIeqdohIdO/5MZnP0OXLyS7C2NCPE\nrwX3dvOgjbsdrq72Mm16NZEkIYSoV1Iz8/lw1TFSMvOxtjRlZN82DOnVEhu5W7pG1GhbbP78+fTp\n04cRI0ZUeG7p0qX4+PiQlZVVkyEIIRqQUwmZvPHNQVIy8xnU05O3ZvYhvF87SRA1qEaTxOjRo1m6\ndGmF7SkpKURFRdG8efOaPLwQooEo0erYevA87/10lMJiHdPu8+Gh0A40sZbkUNNq9HKTv78/ycnJ\nFbYvWrSIuXPn8sQTT9Tk4YUQ9ZhSiuNxGUSfvMSRs2kUFOloYm3O06O70aGlTKVRW2q9T2L79u14\neHjQsWPH2j60EKIeWRUZy+boRACc7K0Y4NuCQX6eODW1MnJkjUutJonCwkI+++wzvvzyS8M2pVSV\ny9/tgt4NidRFGamLMg2lLk7EZfDr/kQ8nG2Z/aAfHVs53vYsrQ2lLoytVpNEYmIiycnJjBo1CqUU\nqampjBkzhpUrV+LkdOu5VGRIWykXFzupi/+RuijTUOqiqFjHe98fAgXThvrgZGNOenrube2jodRF\ndbjbZFnjSeLalkKHDh3Ys2eP4fHAgQNZs2YNTZvK7fJCiFKrdsZy6XIBQwNa4e0p3w3GVqOjm+bM\nmcOECROIj48nODiYn3/+udzzGo3mti43CSEatphzl9l2KAkPJxvC+7U1djiCGm5JvPvuuzd9ftu2\nbTV5eCFEPXL4TBpfbjyFRgOPDO+MhbmpsUMSyB3XQggj0+r0rIqMZcuB85ibmTD9vk60a25v7LDE\n/0iSEEIYTWJqDl//+jfxF7Nxb2bDE+FdaenaxNhhiWtIkhBC1LrcghLW7Ioj8kgySkFQFzcmh3XE\nykK+kuoa+USEELXm6uytG/cmkFeoxcPJhomD29O1rSwnWldJkhBC1LgSrY7IIxfYtPccV/KKsbY0\nY8JAbwb29JQ1H+o4SRJCiBqhV4qzSVfYdzKVA6dSySvUYmlhyv19WhMW0Apbmbm1XpAkIYSodhcz\n8vho1TFSLxcA0NTWgvt6tyYsoCV2NhZGjk7cDkkSQohqdTmniPd+PEJGdhG9u7jRt6sHnVo7yjrT\n9ZQkCSFEtckvLOG9n0oTRET/dozo08bYIYm7JD1GQohqUVSi46NVx0hOy2OQnyf3B7U2dkiiGkhL\nQghxx/ILtRw9m87hs+n8FZdBYbGOXj6uTBzc/ran9hZ1kyQJIcQdOZt8hU9+PkZ2fgkALg5WDOrp\nyci+baX/oQGRJCGEuG37TqTw5aYY9HrF/X1aE9jJjebOttJ6aIAkSQghqkwpxdrd8azbk4C1pSlP\nhHeTu6UbOEkSQogqUUqx/PfTbP8zGeemVjw7rgctnG2NHZaoYZIkhBC3pFeK5VtOs+NwMp4uTXhh\ngi/2tnJTXGNQo0li/vz5REZG4uTkxPr16wF4++232bFjBxYWFrRq1Yo333yTJk1kamAh6iq9Uny3\n5TSRh5Np6VqaIOSu6cZDo2pw/dCDBw9ia2vL3LlzDUkiKiqK3r17Y2Jiwn/+8x80Gg1z5syp0v5k\nYfNSssh7GamLMtVVF4XFWtbtTiD2whWy84rJzi+moEhHK9cmvDDxHppY1/05l+S8KOPiYndX5Wu0\nJeHv709ycnK5bX369DH87evry2+//VaTIQghbkNs8hW+WH+SS1kFaDRgZ22Ok70VLVya8FBoh3qR\nIET1MmqfxKpVqxg+fLgxQxBCAHq9Yn1UAuv3JKCUYlhgK8L7tcPcTCZlaOyMliT++9//Ym5uzogR\nI6pc5m6bTQ2J1EUZqYsyd1IX+YUlvPPdIQ6eSsXZwZrZE/3o5u1cA9HVLjkvqodRksSaNWvYuXMn\n33zzzW2Vk2uMpeR6axmpizJ3UheZ2YV8uOoY5y/l0qVtM2aO6oKtlXm9r1M5L8rU6T4JKB1bfa1d\nu3axZMkSvvvuOywsZISEEMZyLiWHD1cdJSu3mGDf5jwY2kFWiRMV1GiSmDNnDtHR0WRlZREcHMys\nWbNYvHgxJSUlTJ8+HYAePXrwz3/+sybDEEJc5+/Ey3y46hhFxTrGh3gTFtBSptQQlarRIbDVTZqP\npaQpXUbqokxV6+LI2XT++8tf6PWKx0Z0JqCTWy1EV7vkvChT5y83CSGMK7eghKycIrLzi0lMzWVV\nZCxmphqeGdudbu1k3iVxc5IkhGjA/jh6gW9++xudvuyCgY2lGc+O6057TwcjRibqC0kSQjRQx2Iz\n+PrXv7GxMqNXJ1fsbSywszGnWzsnXBysjR2eqCckSQjRACWkZPPfX/7C1FTDs2O749WiqbFDEvWU\njHcTooFJSsvlg5XHKC7RMWNEF0kQ4q5IS0KIBuDvxMus3h3PgRMppF4uAOCh0A707Ohi5MhEfSdJ\nQoh6rKhEx4/bzxJ5uHQiTUsLU+5p70xAJzcCOze8oa2i9kmSEKKeSrqUy2frTnAhPQ9PF1seH90d\nN3tLuWtXb7olAAAgAElEQVRaVCtJEkLUMyVaPb/tT2TdngS0Oj2D/DwZF+JFi+YOcgOZqHaSJISo\nR04mZPLdltOkZOZjb2PO1GFduKe99DuImiNJQoh6oLBYy/LfT7PneAoaYKBfC0b3b4eNlSwCJGqW\nJAkh6rjE1Bz++8tfpF4uoLWbHVOHdaSNu72xwxKNhCQJIeoopRTb/0zmx+1n0OoUQwNaMXpAO+mY\nFrVKkoQQdZBSip93xrFp3zmaWJvz6P2d6O5V/1eLE/WPJAkh6hi9UqzYeoath5Jwc7TmhQn34NTU\nythhiUbqlkni/PnzrFq1iujoaFJSUrC0tMTHx4ewsDCGDBmCmdmNdzF//nwiIyNxcnJi/fr1AFy5\ncoXnn3+e5ORkPD09+eCDD7Czk7VoReNUotVzObeIgkItCoVSsPNIMruOXqSFiy0vPOBL0yaWxg5T\nNGI3XXTo//7v/zhx4gRDhw7lnnvuwdnZmaKiImJjY9m9ezcnT57kn//8J76+vpWWP3jwILa2tsyd\nO9eQJN555x0cHBx47LHH+Pzzz8nOzuaFF16oUrAyBryULKhSpj7WRVpWAd/8GsP5S7lk55dU+prW\nbnbMmeBLE+uqj16qj3VRU6QuytTookODBg1iwYIFFbZ37NiR++67j6ysLM6fP3/D8v7+/iQnJ5fb\ntm3bNr777jsAIiIimDx5cpWThBD1XWzyFT76+Rg5+SW4OlrTwqUJjnaW2FiZYaLRoNGAjZU5g/xa\nyPBWUSfcNEkMGDDgpoUdHBxwcLi9hUsyMzNxdi7tgHNxceHy5cu3VV6I+upAzCWWbDiJTqeYPKQD\nIX6exg5JiFuq0li6f//73+Tk5KDVannwwQfx9fVl7dq1NR2bEPVadl4xO48k882vMfxr2YHS9R1M\nNDw7rrskCFFvVGl0U1RUFC+99BKRkZG4ubnx/vvvM2PGDEaNGnXbB3RyciI9PR1nZ2fS0tJo1qxZ\nlcve7bW1hkTqokxdrIviEh3zv9hHSkY+AGamJnRq04wnx/agjUfN3QhXF+vCWKQuqsdtDYE9cOAA\noaGhuLm5odFoqlTm+n7xgQMHsnr1ambMmMGaNWsYNGhQlY8vHVGlpFOuTF2ti1+jE0nJyCeoiztD\nerWkhYut4Sa4moq3rtaFMUhdlLnbZFmly01OTk688sorbNq0ib59+6LVatHpdLcsN2fOHCZMmEB8\nfDzBwcH8/PPPzJgxg6ioKMLCwti7dy8zZsy4qzcgRF2TW1DChqgEbK3MeDC0Pa3d7eQuaVFvVakl\n8e6777Ju3TrGjh1L06ZNSUpKYtq0aVUqV5lly5bdVpBC1CcbohLIL9LywEBvbGWEkqjnqpQkmjVr\nxsMPP2x47OnpiaendLwJcb20rAK2/5mEc1MrBkrntGgAbtoGfvLJJzl27Filz+Xm5vL111/z448/\n1khgQtRHa3bFodUpRvdvh7mZXGIS9d9NWxLPPPMM7777LgkJCXTv3h0nJyeKioqIi4sjOTmZCRMm\nMHHixNqKVYg6Sa9XHI1NZ+vBJE6du0xrdzsCZH1p0UDcNEn4+PjwxRdfcPHiRfbv309qaiqWlpYM\nHTqUnj17YmFhUVtxClEnnU26wpINJ7mUVQBAp9aOPBTaAZMqjv4Toq6rUp+Eh4fHHd0TIURDlltQ\nwqe/HCc7r4T+PTwY3LMlnq5NjB2WENWqShdNMzIyeOGFF3jooYcAiImJ4YcffqjRwISo677b8jdZ\nucWE92vLw8M6SYIQDVKVksQrr7xCz549yc7OBqBdu3Z8//33NRqYEHXZvpMp7D91Ce8WTRnWu5Wx\nwxGixlQpSaSmpjJx4kRMTU0BsLCwwMRERm6Ixikzu5DvfjuNpbkpj97fCVP5XxANWJXO7usXFsrO\nzq4w3YYQjUFqZj6frD5OfpGWCYO8cXW0MXZIQtSoKnVcDxkyhP/7v/8jLy+P1atX8/333zNmzJia\njk2IOkOr0/Pb/kTW7k5Aq9PTp6s7/Xs0N3ZYQtS4KiWJRx99lHXr1pGdnc3OnTuZPHmyjHYSjYJS\nimOxGfy8M46ktFzsbS14KLQD/h1dqjzJpRD1WZVngR05ciQjR46syViEqDZanf6uJtXT6vT8FZfJ\nuj3xJKSUziZ6b3cPmY9JNDpVShIZGRl89913JCYmotVqDds//PDDGgtMiDuVmJrDv5Yd4KmIbvh1\ncKlSGaUUv/wRz7HYDC7nFBrWntYAvXxcGdG3DZ4uMsRVND5VShJPPvkknTt3JigoyDDCSYi66kzS\nFZSCY7EZVU4S6/YksD4qAXMzE5rZWdLc2Ra3ZjYM6ukpyUE0alVKEgUFBbz22ms1HYsQ1eLS5dIp\nMs6lVm3RmX0nU1i7Ox7npla8MsUfe1uZbkaIq6qUJHr06MHff/9Nx44dazoeIe5a2v/mUUpOy63Q\nN/H7gfOsj0qgh7cTfbt6YGZmwpcbY7C2NOXZsd0lQQhxnSoliQkTJjBp0iTc3d2xtLQ0bF+1atUd\nH3jZsmWsWrUKjUZDhw4dePPNN2XCQFEtrk62p9UpLqTn0cqtbPnGvSdSyC0oYc/xFPYcTwFAo4Fn\nRvWghVxWEqKCKiWJF198kZkzZ9K5c+dq6ZNITU3l22+/ZfPmzVhYWPDcc8+xadMmwsPD73rfonHT\nK2VoSUDpJaerSaKwWEtiai5eze0ZG+zF7uMXOR6bQXj/dnRt52SskIWo06qUJCwtLXnkkUeq9cB6\nvZ6CggJMTEwoLCzE1dW1WvcvGqcrucWUaPU42VuRkV1IYmqu4bnYC9nolaJDSwc6tnKkYytHI0Yq\nRP1QpYHk/fr1Y9euXdV2UDc3N6ZNm0ZwcDD9+/fHzs6OPn36VNv+ReN16XI+APd0cMZEoynXeX3m\nfBYA7T0djBKbEPVRlVoSP/30E59//jm2trZYWFiglEKj0bB37947Omh2djbbtm1jx44d2NnZ8cwz\nz7B+/XpGjBhxR/sT4qqr/RGeLk3wcLLhfGoueqUw0Wg4k3QFAG/PpsYMUYh6pUpJ4ueff67Wg0ZF\nRdGyZUscHEp/0YWGhnL48OFbJgkXF7ubPt+YSF2UubYu8or1ALRv04xzl3JJPpRECRrcm9kSdzGb\nVu52tG3VzFih1jg5L8pIXVSPKiWJFi1aVOtBmzdvztGjRykqKsLCwoJ9+/bRrVu3W5ZLS6vauPeG\nzsXFTurif66vi4Tk0ktKFoCbgzUAh0+l4OZoQ1GxjnYe9g227uS8KCN1UeZuk+VNk8SLL77IO++8\nw5gxYyqdzOxOh8B2796dsLAwwsPDMTMzo3PnzowfP/6O9iXEtdKyCjAz1eBoZ0lrt9IhrYmpuVzJ\nLQagvVxqEuK23DRJXLp0CYB//OMf1X7gp59+mqeffrra9ysat7SsQpyaWmNioqGla+kvqHMpOVhb\nlp7qHaTTWojbctMkcXW50oCAgFoJRoi7kV+oJbeghLYe9gDYWJnh6mBNYmoOJiYamtlb4tTUyshR\nClG/yLqLosG4ehOd6//6IgBauduRV6glJ79EWhFC3IGbtiROnz5NUFBQhe13OwRWiJpwdfiri2NZ\nkmjt1oSDMaWXTaU/Qojbd9Mk0aZNGz7//PPaikWI25KTX4xdcdn6JldvpLu2JdH6mnmb2reUloQQ\nt+umScLCwqLah78KUR2KS3S8siQab08HZo0uHT6dVklL4uq8TbZWZjR3tq39QIWo526aJMzNZZlG\nUTfFJGaRk1/C4dNpnEvJobW7nWEdCZdrOqftbS3o29UdFwdrTGRNaiFu2007rn/66afaikOI23Is\nNt3w9+8HzwOlLQlHO0sszMvPVPzI/Z0ZeW/bWo1PiIZCRjeJekcpxbHYDKwtzWjh0oTok6mkXykg\nM7sIl2v6I4QQd0+ShKh3Lmbkk36lkC5tmzGqfzt0esWqyFgU4OIg90EIUZ0kSYh651hsBgA9vJwI\n6dkSWysz9p8qHebqKi0JIaqVJAlR71ztj+jazgkrSzMG+JaNwLt2ZJMQ4u5JkhD1Sn6hljNJV2jr\nYUdT29I10Qf6tcDUpHTkkquDjTHDE6LBkSQh6pWTCZno9Ipu16xJ3czeij5d3bG2NMPDSZKEENWp\nSutJCFFXGPojvJ3LbZ8ytCMPDGxvmO1VCFE95D9K1Bt6pTgel4G9jTmt3csvpGJqYoKNlTSMhahu\n8l8l6o0T8ZlcySumWzsnuXtaiFpitJZETk4OL7/8MmfOnMHExIRFixbRo0cPY4Uj6jCtTs+GqAQ2\nRJ1DAwR1dTd2SEI0GkZLEgsXLmTAgAF89NFHaLVaCgsLjRWKqMMupOexdONJ4i/m4GRvySPDO+PT\n2tHYYQnRaBglSeTm5nLw4EH+/e9/lwZhZkaTJk2MEYqoo/ILtazbE8+2Q0no9IqgLu48FNoBGyvp\nRhOiNhnlPy4pKQlHR0fmzZtHTEwMXbt25eWXX8bKSqZUaOwKirTsO5nK2j/iyM4vwbmpFRMHteee\nDi7GDk2IRkmjlFK1fdC//vqLBx54gBUrVtCtWzcWLlyInZ0dzzzzTG2HIuoApRSH/05j+8Hz7P3r\nIsUlOqwsTBk3qAPhA7wqzOoqhKg9RmlJuLu74+7uTrdupYvFhIWFsWTJkluWS0vLqenQ6gUXF7sG\nUxdZuUV8ufEUf8VnAuDmaE1QV3f6dW+Oo50lV7Lyb1q+IdXF3ZK6KCN1UcbFxe7WL7oJoyQJZ2dn\nPDw8iI+Pp23btuzbtw8vLy9jhCKM6PDpNL7aHENuQQld2zVjVN+2tGtuj0aGtwpRZxitF/CVV17h\nhRdeQKvV0rJlS958801jhSJqWVGxjhXbz7DzyAXMzUx4KLQDA/1aSHIQog4yWpLw8fHh559/Ntbh\nhZEkpGSzeN1JUjPz8XRpwuMjO9PCRUa2CVFXyXhCUSu0Oj2/7U/klz/i0ekVYQEtGd3fC3Mzuelf\niLpMkoSoUVeXGv1x+1lSMvNp2sSCR4d3pkvbZsYOTQhRBZIkRI1JTM1hVWQsf8VnotFAiF8LIvq1\no4m1ubFDE0JUkSQJUe0SU3NYtyeBP0+nAdCljSMPDGqPp/Q9CFHvSJIQ1SY9q4CVkbEciCldb9qr\nuT2j7m1Ll7bNZOSSEPWUJAlx1wqKtGzad47f9p9Hq9PT1sOO8H7t6CrJQYh6T5KEuGNKKfaeSGFl\nZCxXcotxtLNkbLAXgZ3dZL0HIRoISRLijpxLyWH576c5m3wFczMTRvZtw7DA1lhayDxLQjQkkiTE\nbSko0rJmVxzbDiWhgJ4dXHhgoDfODtbGDk0IUQMkSYgqOxabwbe/xZCRXYRbMxsmDelAlzZyv4MQ\nDZkkCXFLl3OK+GnHWaJPpmJqouH+Pm0Y0ac15mZyaUmIhk6ShLghrU7P1oNJrN0TT1GxjrYe9kwb\n5oOnq9zvIERjIUlCVKCU4sjZdH7eGceF9DyaWJszYag3/Xo0l1FLQjQykiREOafOXWb1zlhiL2Sj\n0cAA3+aMGeAlU2kI0UhJkhBA6Qpxy7ec5tD/ptLo2cGF8P7taOFsa+TIhBDGJEmikVNK8cexi/y4\n/SwFRVraezZlwqD2tPWwN3ZoQog6wKhJQq/XM2bMGNzc3Pjss8+MGUqjdC4lhxXbzvD3+SysLEyZ\nHNaRAb7S7yCEKGPUJPHNN9/g5eVFbm6uMcNodDKzC1mzK46ov1JQgK+3M5OGdKCZvZWxQxNC1DFG\nSxIpKSns3LmTmTNn8tVXXxkrjEalqETHb9GJbNp3jmKtHk+XJjwwyFtuiBNC3JDRksSiRYuYO3cu\nOTk5xgqh0VBKcSDmEj/tOEtmdhFNbS14KLQdfbt5YGIil5aEEDdmlCQRGRmJs7MznTp1Ijo6usrl\nXFzsajCq+qUqdaGUYv+JFFZsPc3Z81mYmZowdmB7xg1qj41VwxnSKudFGamLMlIX1UOjlFK1fdD3\n3nuPdevWYWpqSlFREXl5eYSGhvL222/ftFxamrQ6oPTkv1VdHDmTzpo/4jh/KRcN4O/jypgB7XB1\ntKmdIGtJVeqisZC6KCN1UeZuk6VRksS19u/fz5dfflml0U3yoZe62T/A5Zwilv9+mj9Pp6HRQGAn\nN4b3adNg73eQL4MyUhdlpC7K3G2SkPskGgi9Uuw8coFVkWcpKNLRoaUDU8I60ryBJgchRO0wepII\nCAggICDA2GHUa38nXuaHbWdITM3F2tKMqUM7yjxLQohqYfQkIe7cpawCVu44y6G/S6fSCOrixthg\nbxztLI0cmRCioZAkUQ/lF5awKjKWLQcS0eoUXi3smTioA+2ay1QaQojqJUmiHlFKEfVXCqt3xXE5\npwhHO0vGhXgR2MkNjVxaEkLUAEkS9URyeh7f/vY3p89nYWFuyqh72zI0sBWW5rI6nBCi5kiSqOOK\ninWsj0rgt/2J6PSKe9o78/QD96DR6owdmhCiEZAkUUcppYg+mcrKyFgu5xThZG/FQ6Ed8G3vjIuj\njYwBF0LUCkkSddC5lByW/36as8lXMDM14f4+bRjeuzWWFnJpSQhRuyRJ1CG5BSWs3hXHzsPJKKBn\nRxceCPHG2cHa2KEJIRopSRJ1gFanZ/exi6zeFUduQQkeTjY8FNqBzjKFtxDCyCRJGJFWpyfqrxQ2\nRCWQfqUQSwtTxod4M9jfEzNTE2OHJ4QQkiSM5fCZNH7Yeob0K4WYmZowuKcn9wW1xqGJ3C0thKg7\nJEnUsit5xXz/+2kOxFzC1ETDoJ6e3Ne7tUylIYSokyRJ1BK9XvHHsQusiowlr1CLVwt7Hh7WqcFO\n4S2EaBgkSdSC0+ez+H7raRJTc7G0MOWh0A6E+LWQWVqFEHWeJIkalJZVwM87Y9l/6hIAQV3cGRvs\nJZeWhBD1hiSJGpBXWMLGqHNsPXQerU7R1sOeBwe3x6tFU2OHJoQQt8UoSSIlJYW5c+eSnp6Oqakp\n48aNY8qUKcYIpVoppfjj2EVW7jhLXqEWJ3tLxgzwIqCzm1xaEkLUS0ZJEqampsybN49OnTqRl5fH\n6NGj6du3L15eXsYIp1qkXs7n680xxCRmYWVhyrhgLwb7e2JuJlNpCCHqL6MkCRcXF1xcXACwtbXF\ny8uLS5cu1cskUVCk5fcD59m47xwlWj2+3s5MGtKBZvZWxg5NCCHumtH7JJKSkoiJiaF79+7GDuW2\nlGh17Dh8gY17E8jJL8He1oJH7++Af0cXWQBICNFgaJRSylgHz8vLY/LkyTz55JMMHjzYWGHclqIS\nHb9Hn+PnHWdJzyrA2tKMiGBvRvVvh42VubHDE0KIamW0JKHVann88cfp378/U6dOrVIZY66hUFyi\nY/ufyfy6P5HsvGIszEwI8WvBfb1bY2djUauxuLjYyXoS/yN1UUbqoozURRkXF7u7Km+0y03z58/H\n29u7ygnCmOIvZrNkw0kuZuRjZWHK8KDWhPq3xN62dpODEELUNqMkiUOHDrF+/Xo6dOhAeHg4Go2G\n559/nv79+xsjnBvS6vSs25PApr3n0CvF4J6ejOrXFlu5rCSEaCSMkiR69uzJqVOnjHHoKjubdIVv\nfoshKS0PJ3srpg/vRKfWjsYOSwghapXRRzfVNbkFJayKPMuuoxcB6N+jOQ8M9MbaUqpKCNH4yDff\nNf48ncayzTHkFpTg6WLL5LCOtPd0MHZYQghhNJIkKL0h7oetZ9h9/CLmZiaMC/Ei1L+lrA4nhGj0\nGn2SOJmQybLNMaRfKaSVWxMeG9FF1ngQQoj/abRJIuNKIT9uP8PBv9PQAMODWjPq3rbSehBCiGs0\nuiRRotXz6/5ENkYlUKzV49XCnkmhHWntfnc3nAghREPUqJLEsdh0vt96hkuXC7C3tWBymBdBXd1l\nGm8hhLiBRpEk0rIKWLHtDIfPpGOi0RDq35JR97bFxqpRvH0hhLhjDfpbsqhEx+Z959i0LxGtTk+H\nlg5MCu2Ap2sTY4cmhBD1QoNMEkopDv2dxo/bz5KRXYhDEwvGD/QmsJObTOMthBC3ocEliYSUbFZs\nPcPppCuYmmgYFtiK+/u0kTumhRDiDjSYb84recX8HBnLnuMXUcA97Z0ZH+KNWzMbY4cmhBD1Vr1P\nElqdnu2Hkli7J56CIh2eLrZMGNSezm2aGTs0IYSo9+p1kvg78TLfbjnNhfQ8bK3MmDSkAwN8m2Nq\nIjfECSFEdaiXSSK3oISfdpxl97GLaIBg3+ZE9G9X6yvECSFEQ2e0JLFr1y4WLVqEUooxY8YwY8aM\nW5ZRShF9MpXvt54ht6CElq5NmDrUh3bN7WshYiGEaHyMkiT0ej2vv/46y5Ytw9XVlbFjxzJo0CC8\nvLxuWCbjSgEfrTrG0dgMLMxNGB/iTWgvT7m0JIQQNcgoSeLYsWO0bt2aFi1aADB8+HC2bdt20yTx\n1NvbySvU0qm1I1OH+eDqYF1b4QohRKNllCSRmpqKh4eH4bGbmxvHjx+/aRm9gqlDO9K/R3O5IU4I\nIWqJUZKEUuq2yyx9JZTCvKIaiEYIIcSNGCVJuLu7c+HCBcPj1NRUXF1db1rGzsZCRi9dw8VFpja/\nSuqijNRFGamL6mGUXt9u3bqRmJhIcnIyxcXFbNy4kUGDBhkjFCGEEDdhlJaEqakpr776KtOnT0cp\nxdixY2/aaS2EEMI4NOpOOgiEEEI0CnKTgRBCiBuSJCGEEOKGJEkIIYS4oTqfJHbt2sXQoUMJCwvj\n888/N3Y4tSolJYUpU6Zw3333MWLECL755hsArly5wvTp0wkLC+ORRx4hJyfHyJHWHr1eT0REBDNn\nzgQgKSmJ8ePHExYWxuzZs9FqtUaOsHbk5OTwzDPPMGzYMIYPH87Ro0cb7XmxbNky7r//fkaMGMGc\nOXMoLi5uNOfF/Pnz6dOnDyNGjDBsu9l58MYbbzBkyBBGjRrFqVOnqnSMOp0krs7xtHTpUjZs2MDG\njRuJjY01dli1xtTUlHnz5rFp0yZWrFjB8uXLiY2N5fPPPycoKIjffvuNwMBAFi9ebOxQa80333xT\nbiTcf/7zH6ZNm8Zvv/2GnZ0dq1atMmJ0tWfhwoUMGDCAzZs3s3btWtq1a9coz4vU1FS+/fZbVq9e\nzfr169HpdGzcuLHRnBejR49m6dKl5bbd6DzYuXMniYmJbNmyhQULFvDaa69V6Rh1OklcO8eTubm5\nYY6nxsLFxYVOnToBYGtri5eXF6mpqWzbto2IiAgAIiIi2Lp1qzHDrDUpKSns3LmTcePGGbbt27eP\nsLAwoLQufv/9d2OFV2tyc3M5ePAgY8aMAcDMzAw7O7tGe17o9XoKCgrQarUUFhbi6upKdHR0ozgv\n/P39sbcvPwv29efB1e/Mbdu2ER4eDkCPHj3IyckhPT39lseo00misjmeLl26ZMSIjCcpKYmYmBh6\n9OhBRkYGzs7OQGkiuXz5spGjqx2LFi1i7ty5hrm7Ll++TNOmTTH530zA7u7ujeL8SEpKwtHRkXnz\n5hEREcGrr75KQUFBozwv3NzcmDZtGsHBwfTv3x87Ozs6d+6Mvb19ozsvrsrMzCx3HmRmZgJw6dIl\n3N3dDa9zc3MjNTX1lvur00lCbuEolZeXxzPPPMP8+fOxtbVtlBMcRkZG4uzsTKdOnQznhVKqwjnS\nGOpGq9Vy8uRJHnzwQdasWYO1tTWff/55o3jv18vOzmbbtm3s2LGDP/74g4KCAnbt2lXhdY2xbq5X\n2fdpVeqlTq9MdydzPDU0Wq2WZ555hlGjRjF48GAAnJycSE9Px9nZmbS0NJo1a/jref/5559s376d\nnTt3UlRURF5eHosWLSInJwe9Xo+JiQkpKSmN4vxwd3fH3d2dbt26ATBkyBC++OKLRnleREVF0bJl\nSxwcHAAYPHgwhw8fJjs7u9GdF1fd6Dxwc3MjJSXF8Lqq1kudbknIHE+loxe8vb2ZOnWqYdvAgQNZ\nvXo1AGvWrGkUdTJ79mwiIyPZtm0b7733HoGBgfznP/8hMDCQX3/9FWg8deHs7IyHhwfx8fFAab+M\nt7d3ozwvmjdvztGjRykqKkIpxb59+2jfvn2jOi+ubyHc6DwYNGgQv/zyCwBHjhzB3t7ecFnqZur8\ntBy7du1i4cKFhjmeqrLMaUNx6NAhJk2aRIcOHdBoNGg0Gp5//nm6d+/Oc889x8WLF2nevDkffvhh\nhc6rhmz//v18+eWXfPbZZ5w/f57Zs2eTnZ1Np06deOeddzA3Nzd2iDUuJiaGl19+Ga1WS8uWLXnz\nzTfR6XSN8rz45JNP2LhxI2ZmZnTu3Jk33niDlJSURnFezJkzh+joaLKysnB2dmbWrFkMHjyYZ599\nttLzYMGCBfzxxx9YW1vz5ptv0qVLl1seo84nCSGEEMZTpy83CSGEMC5JEkIIIW5IkoQQQogbkiQh\nhBDihiRJCCGEuCFJEkIIIW5IkoSocwYOHMjZs2dr5ViffPJJuWmk582bx/Lly+96v/PmzWPEiBHM\nnj37rvd1MzExMWzevLlGjyEaN0kSolH75JNPKCkpqdZ9pqens2XLFtavX897771Xrfu+3smTJ+84\nSej1+mqORjREkiREvREfH89jjz3GuHHjCA8PN0w9AODj48PixYsZO3YsoaGhbNmyxfDcb7/9xrBh\nwxg9ejSLFy/Gx8eHgoICFixYgEajYcKECURERJCbmwvA6dOnmTp1KmFhYbz00ks3jOeXX35hxIgR\njBo1ilmzZpGZmUleXh5Tp06lqKiIiIgIvv7663Jl1q5dy9NPP214rNPp6Nevn2GOsiVLljB+/HhG\njx7NE088QUZGBgAlJSW89dZbjBgxgvDwcGbNmkVWVhYff/wx+/btIyIigoULFwKlsxREREQwatQo\npk2bxvnz54HSO9XDw8N54403mDBhAn/88cfdfByisVBC1DEhISHqzJkz5bZptVoVERGh4uLilFJK\n5ebmqrCwMMPjjh07quXLlyullDp06JDq16+fUkqp9PR0FRAQoBITE5VSSn311VfKx8dH5efnG8oV\nFGMkQHAAAAOxSURBVBQYjvPSSy+pBx98UBUXF6vi4mI1fPhwFRUVVSHG06dPq3vvvVelp6crpZT6\n4IMP1HPPPaeUUiopKUn17t270vdWUFCgevfurS5fvqyUUmr79u1q6tSpSiml1q5dq1599VXDa7//\n/ns1Z84cpZRSH3/8sZo1a5bSarVKKWUov3r1avXMM88YymRkZKjevXur2NhYpZRSK1euVOPGjVNK\nKRUdHa06d+6sjh49WmlsQlRGWhKiXkhISCAuLo7Zs2cTHh7OQw89RElJSbmVCu+77z4AfH19SUtL\no7i4mKNHj9K1a1datmwJwNixYyvsW103M83gwYMxNzfH3Nyczp07k5iYWKFMdHQ0wcHBODk5ATBh\nwgSioqJu+T6srKwYNGgQGzZsAEonYLu6eND27dvZu3cv4eHhhIeH8/3333Px4kWgdKr0KVOmYGpq\nCmCY9fR6R48epVOnTrRr1w6AMWPGcOrUKfLz8wFo3bo13bt3v2WcQlxVp6cKF+IqpRTNmjVjzZo1\nlT6v0WiwtLQEMCw2o9PpKiSA6x9XxsLCwvC3qalppesjK6UqzMV/9bi3Eh4ezptvvsn999/P/v37\neeeddwz7fOKJJxg9enSlx6uKyuK69rGNjU2V9iPEVdKSEPVC27ZtsbKyYu3atYZtcXFx5OXlARW/\nRK8+9vX15cSJE4br8tf2YwA0adKk3ELxVRUUFMTOnTsNfQY//vgjffr0qXD8yvj7+5Obm8t7771H\naGioIbkNHDiQ77//nuzsbACKi4uJiYkBICQkhG+++cbQyX511bkmTZoY+lKuvt9Tp04ZphFfvXo1\nnTt3luQg7pi0JESdo9FoePjhhzEzMzP8Ml6/fj2fffYZCxcu5Msvv0Sn0+Hs7MwHH3xgKHP9PqB0\nAZZ//etfzJgxA0dHR4KDgzEzM8Pa2hqAadOmMWXKFKytrfn222+rHKO3tzezZ8/m4YcfxsTEhJYt\nW7JgwYIKx7+R8PBwPvroo/9v5w5xGASiIAwPBoMhHADNBRCcgtUEzQWQSByChAOgSHB4joVBLqlo\ngnumadK0/T/51LrZyeat1nW9Z2VZ6jgO1XWtIAh0XZeqqlKWZWqaRuM4yjmnMAyVpqmmaVJRFJrn\nWc455Xmurus0DIPatpX3XkmS3E0FeAVfhePnneepKIokPW/W27a9ZRcC+Ac0Cfy8ZVm077u894rj\nWH3ff/pIwNegSQAATDxcAwBMhAQAwERIAABMhAQAwERIAABMhAQAwPQAVSnSA55bZkwAAAAASUVO\nRK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b8e3bd90\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(eager_means)\n", + "plt.ylabel('Time(s)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (Eager)')\n", + "_ = plt.ylim(ymin=0)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Autograph vs. Eager vs Graph sum", + "provenance": [ + { + "file_id": "1olZkm32B7n7pQwlIAXR0_w8fZhRHCtkX", + "timestamp": 1531755808890 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index f7fe3de5dabdebe255210a7b6247809dbd70d10b..4729c735c621e68df30acfd4738d89874c3c55ac 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -23,7 +23,6 @@ from functools import wraps from enum import Enum # pylint:disable=g-bad-import-order -import gast import six # pylint:enable=g-bad-import-order @@ -69,7 +68,8 @@ def convert(recursive=False, verbose=False, arg_types=None): @wraps(f) def wrapper(*args, **kwargs): - return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + return converted_call(f, recursive, verbose, True, arg_types, *args, + **kwargs) wrapper = tf_decorator.make_decorator(f, wrapper) @@ -130,12 +130,12 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): return decorator -def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): +def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, + **kwargs): """Compiles a function call inline.""" # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. - - if conversion.is_whitelisted_for_graph(f): + if not force_conversion and conversion.is_whitelisted_for_graph(f): return f(*args, **kwargs) unknown_arg_value = object() # Sentinel for arguments of unknown value @@ -245,37 +245,41 @@ def to_graph(e, _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - module = gast.Module([]) + nodes = [] for dep in reversed(program_ctx.dependency_cache.values()): - module.body.append(dep) - compiled_node, compiled_src = compiler.ast_to_object( - module, source_prefix=program_ctx.required_imports) + nodes.extend(dep) + compiled_module, compiled_src = compiler.ast_to_object( + nodes, + source_prefix=program_ctx.required_imports, + include_source_map=True) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? for key, val in namespace.items(): # Avoid overwriting entities that have been transformed. - if key not in compiled_node.__dict__: - compiled_node.__dict__[key] = val - compiled_fn = getattr(compiled_node, name) + if key not in compiled_module.__dict__: + compiled_module.__dict__[key] = val + compiled = getattr(compiled_module, name) # Need this so the source_mapping attribute is available for the context # manager to access for runtime errors. # # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' # symbol to the compiled module. + # TODO(mdan): Record this statically in the generated code. + # TODO(mdan): Rename this attribute to 'autograph_info__' source_map_attribute_name = 'ag_source_map' - if getattr(compiled_fn, source_map_attribute_name, None) is not None: + if getattr(compiled, source_map_attribute_name, None) is not None: raise ValueError('cannot convert %s because is has an attribute ' '"%s", which is reserved for AutoGraph.' % - (compiled_fn, source_map_attribute_name)) - setattr(compiled_fn, source_map_attribute_name, - compiled_node.__dict__['ag_source_map__']) + (compiled, source_map_attribute_name)) + setattr(compiled, source_map_attribute_name, + compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) - return compiled_fn + return compiled def to_code(e, @@ -308,7 +312,7 @@ def to_code(e, conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) code = '\n'.join( - compiler.ast_to_source(dep, indentation)[0] + compiler.ast_to_source(dep, indentation) for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) return program_ctx.required_imports + '\n\n' + code diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 4de7df657204db2f625098d15e475f942eb352b8..803fde9089b1c004d9bfc0dfefd3d6b422752f0a 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -183,8 +183,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.converted_call(self.called_member, False, False, {}, self, - a) + x //= api.converted_call(self.called_member, False, False, False, {}, + self, a) return x tc = TestClass() @@ -195,7 +195,7 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_converted_call_builtin(self): - x = api.converted_call(range, False, False, {}, 3) + x = api.converted_call(range, False, False, False, {}, 3) self.assertEqual((0, 1, 2), tuple(x)) def test_converted_call_function(self): @@ -206,7 +206,7 @@ class ApiTest(test.TestCase): return x with self.test_session() as sess: - x = api.converted_call(test_fn, False, False, {}, + x = api.converted_call(test_fn, False, False, False, {}, constant_op.constant(-1)) self.assertEqual(1, sess.run(x)) @@ -224,7 +224,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, False, False, {}, tc) + x = api.converted_call(tc.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_method_by_class(self): @@ -241,7 +241,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(TestClass.test_method, False, False, {}, tc) + x = api.converted_call(TestClass.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_callable_object(self): @@ -258,7 +258,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc, False, False, {}) + x = api.converted_call(tc, False, False, False, {}) self.assertEqual(1, sess.run(x)) def test_converted_call_constructor(self): @@ -274,12 +274,27 @@ class ApiTest(test.TestCase): return self.x with self.test_session() as sess: - tc = api.converted_call(TestClass, False, False, {}, + tc = api.converted_call(TestClass, False, False, False, {}, constant_op.constant(-1)) # tc is now a converted object. x = tc.test_method() self.assertEqual(1, sess.run(x)) + def test_converted_call_already_converted(self): + + def f(x): + return x == 0 + + with self.test_session() as sess: + x = api.converted_call(f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + + converted_f = api.to_graph(f) + x = api.converted_call(converted_f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + def test_to_graph_basic(self): def test_fn(x, s): diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 7bd0ba3f2dd2313552f0e184c3301a426f90dce2..fc8a976d3f3ecdc9c6339995dd0dfc776824b90d 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -48,6 +48,7 @@ from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -70,6 +71,8 @@ def is_whitelisted_for_graph(o): for prefix, in config.DEFAULT_UNCOMPILED_MODULES: if m.__name__.startswith(prefix): return True + if hasattr(o, 'autograph_info__'): + return True return False @@ -115,12 +118,32 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) + # TODO(mdan,yashkatariya): Remove when object conversion is implemented. + elif hasattr(o, '__class__'): + raise NotImplementedError( + 'Object conversion is not yet supported. If you are ' + 'trying to convert code that uses an existing object, ' + 'try including the creation of that object in the ' + 'conversion. For example, instead of converting the method ' + 'of a class, try converting the entire class instead. ' + 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/' + 'contrib/autograph/README.md#using-the-functional-api ' + 'for more information.') else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) + # TODO(mdan): This is temporary. it should be created using a converter. + # TODO(mdan): The attribute should be added with a helper, not directly. + # The helper can ensure there are no collisions. + template = ''' + entity.autograph_info__ = {} + ''' + node.extend(templates.replace(template, entity=name)) + program_ctx.add_to_cache(o, node) + if program_ctx.recursive: while True: candidate = None @@ -164,21 +187,21 @@ def class_to_graph(c, program_ctx): class_namespace = namespace else: class_namespace.update(namespace) - converted_members[m] = node + converted_members[m] = node[0] namer = program_ctx.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. - # Process any base classes: if the sueprclass if of a whitelisted type, an + # Process any base classes: if the superclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by # program_ctx.update_name_map(namer)). output_nodes = [] renames = {} - bases = [] + base_names = [] for base in c.__bases__: if isinstance(object, base): - bases.append('object') + base_names.append('object') continue if is_whitelisted_for_graph(base): alias = namer.new_symbol(base.__name__, ()) @@ -190,28 +213,28 @@ def class_to_graph(c, program_ctx): else: # This will trigger a conversion into a class with this name. alias = namer.compiled_class_name(base.__name__, base) - bases.append(alias) + base_names.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) program_ctx.update_name_map(namer) # Generate the definition of the converted class. - output_nodes.append( - gast.ClassDef( - class_name, - bases=bases, - keywords=[], - body=list(converted_members.values()), - decorator_list=[])) - node = gast.Module(output_nodes) - + bases = [gast.Name(n, gast.Load(), None) for n in base_names] + class_def = gast.ClassDef( + class_name, + bases=bases, + keywords=[], + body=list(converted_members.values()), + decorator_list=[]) # Make a final pass to replace references to the class or its base classes. # Most commonly, this occurs when making super().__init__() calls. # TODO(mdan): Making direct references to superclass' superclass will fail. - node = qual_names.resolve(node) + class_def = qual_names.resolve(class_def) renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) - node = ast_util.rename_symbols(node, renames) + class_def = ast_util.rename_symbols(class_def, renames) - return node, class_name, class_namespace + output_nodes.append(class_def) + + return output_nodes, class_name, class_namespace def _add_reserved_symbol(namespace, name, entity): @@ -268,18 +291,18 @@ def function_to_graph(f, context = converter.EntityContext(namer, entity_info, program_ctx) node = node_to_graph(node, context, rewrite_errors=rewrite_errors) - # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) if not did_rename: new_name = f.__name__ if node.name != f.__name__: raise NotImplementedError('Strange corner case. Send us offending code!') - node.name = new_name + program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - return node, new_name, namespace + return [node], new_name, namespace def node_to_graph(node, context, rewrite_errors=True): diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 207225a1acec7db64cc121d7451e4bf34dabc7e7..86432573a719ea3f2b163746996dbf3301785a91 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -50,7 +50,7 @@ class ConversionTest(test.TestCase): self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) def test_entity_to_graph_unsupported_types(self): - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): program_ctx = self._simple_program_ctx() conversion.entity_to_graph('dummy', program_ctx, None, None) @@ -60,10 +60,11 @@ class ConversionTest(test.TestCase): return a + b program_ctx = self._simple_program_ctx() - ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - self.assertTrue(isinstance(ast, gast.FunctionDef), ast) + nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) + fn_node, _ = nodes + self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) - self.assertTrue(ns['b'] is b) + self.assertIs(ns['b'], b) def test_entity_to_graph_call_tree(self): @@ -78,14 +79,11 @@ class ConversionTest(test.TestCase): self.assertTrue(f in program_ctx.dependency_cache) self.assertTrue(g in program_ctx.dependency_cache) - self.assertEqual('tf__f', program_ctx.dependency_cache[f].name) - # need one extra .body[0] in order to step past the try/except wrapper that - # is added automatically, the other for the with tf.name_scope('f') that is - # added automatically - self.assertEqual( - 'tf__g', - program_ctx.dependency_cache[f].body[0].body[0].body[0].value.func.id) - self.assertEqual('tf__g', program_ctx.dependency_cache[g].name) + f_node = program_ctx.dependency_cache[f][0] + g_node = program_ctx.dependency_cache[g][0] + self.assertEqual('tf__f', f_node.name) + self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id) + self.assertEqual('tf__g', g_node.name) def test_entity_to_graph_class_hierarchy(self): @@ -117,10 +115,12 @@ class ConversionTest(test.TestCase): self.assertTrue(TestBase in program_ctx.dependency_cache) self.assertTrue(TestSubclass in program_ctx.dependency_cache) + # The returned nodes will include: + # , , self.assertEqual('TfTestBase', - program_ctx.dependency_cache[TestBase].body[-1].name) + program_ctx.dependency_cache[TestBase][-2].name) self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -139,10 +139,11 @@ class ConversionTest(test.TestCase): self.assertTrue(TestSubclass in program_ctx.dependency_cache) self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( - 'Model', - program_ctx.dependency_cache[TestSubclass].body[0].names[0].name) + 'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name) + # The returned nodes will include: + # , , self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_lambda(self): f = lambda a: a diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 988df70157170ed0a9ece33976e871e6f7693bbc..9909e521644a7a901653dc09853222167828c75c 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -141,7 +141,7 @@ def _dataset_for_stmt(ds, extra_test, body, init_state): while_body, init_state=(epoch_number, iterate) + init_state, extra_deps=()) - # Dropping the epoch number and iterate because they are not not syntactically + # Dropping the epoch number and iterate because they are not syntactically # visible. results = results[2:] @@ -212,12 +212,12 @@ def if_stmt(cond, body, orelse): Tuple containing the statement outputs. """ if tensor_util.is_tensor(cond): - return _tf_if_stmt(cond, body, orelse) + return tf_if_stmt(cond, body, orelse) else: return _py_if_stmt(cond, body, orelse) -def _tf_if_stmt(cond, body, orelse): +def tf_if_stmt(cond, body, orelse): """Overload of if_stmt that stages a TF cond.""" return control_flow_ops.cond(cond, body, orelse) diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index f77a6ab3928e4f933f4d21abba2030d4d6f8ec0a..ddadc6b96e8eb5417bfa1676ae304f7cbdedd92b 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -99,6 +99,16 @@ py_test( ], ) +py_test( + name = "origin_info_test", + srcs = ["origin_info_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "parser_test", srcs = ["parser_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py index 86e3f56a64d5300d925bc7fa31eaf69cd5e487a5..d7453b078197cd6f1c0521b861e96dd28d287cab 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -20,7 +20,6 @@ from __future__ import print_function import ast -import collections import gast from tensorflow.contrib.autograph.pyct import anno @@ -185,6 +184,7 @@ class PatternMatcher(gast.NodeVisitor): if v != p: return self.no_match() + def matches(node, pattern): """Basic pattern matcher for AST. @@ -253,30 +253,61 @@ def apply_to_single_assignments(targets, values, apply_fn): apply_fn(target, values) -def iter_fields(node): - for field in sorted(node._fields): - try: - yield getattr(node, field) - except AttributeError: - pass - - -def iter_child_nodes(node): - for field in iter_fields(node): - if isinstance(field, gast.AST): - yield field - elif isinstance(field, list): - for item in field: - if isinstance(item, gast.AST): - yield item - - -def parallel_walk(node_a, node_b): - todo_a = collections.deque([node_a]) - todo_b = collections.deque([node_b]) - while todo_a and todo_b: - node_a = todo_a.popleft() - node_b = todo_b.popleft() - todo_a.extend(iter_child_nodes(node_a)) - todo_b.extend(iter_child_nodes(node_b)) - yield node_a, node_b +def parallel_walk(node, other): + """Walks two ASTs in parallel. + + The two trees must have identical structure. + + Args: + node: Union[ast.AST, Iterable[ast.AST]] + other: Union[ast.AST, Iterable[ast.AST]] + Yields: + Tuple[ast.AST, ast.AST] + Raises: + ValueError: if the two trees don't have identical structure. + """ + if isinstance(node, (list, tuple)): + node_stack = list(node) + else: + node_stack = [node] + + if isinstance(other, (list, tuple)): + other_stack = list(other) + else: + other_stack = [other] + + while node_stack and other_stack: + assert len(node_stack) == len(other_stack) + n = node_stack.pop() + o = other_stack.pop() + + if (not isinstance(n, (ast.AST, gast.AST)) or + not isinstance(o, (ast.AST, gast.AST)) or + n.__class__.__name__ != o.__class__.__name__): + raise ValueError('inconsistent nodes: {} and {}'.format(n, o)) + + yield n, o + + for f in n._fields: + n_child = getattr(n, f, None) + o_child = getattr(o, f, None) + if f.startswith('__') or n_child is None or o_child is None: + continue + + if isinstance(n_child, (list, tuple)): + if (not isinstance(o_child, (list, tuple)) or + len(n_child) != len(o_child)): + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) + node_stack.extend(n_child) + other_stack.extend(o_child) + + elif isinstance(n_child, (gast.AST, ast.AST)): + node_stack.append(n_child) + other_stack.append(o_child) + + elif n_child != o_child: + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py index 981e398b930e232a563b439659a35376c2995f6c..2293c89720a54f7495670c6f28b00f716cad70db 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -44,7 +44,7 @@ class AstUtilTest(test.TestCase): node, {qual_names.QN('a'): qual_names.QN('renamed_a')}) self.assertIsInstance(node.body[0].value.left.id, str) - source, _ = compiler.ast_to_source(node) + source = compiler.ast_to_source(node) self.assertEqual(source.strip(), 'renamed_a + b') def test_rename_symbols_attributes(self): @@ -54,7 +54,7 @@ class AstUtilTest(test.TestCase): node = ast_util.rename_symbols( node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}) - source, _ = compiler.ast_to_source(node) + source = compiler.ast_to_source(node) self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d') def test_rename_symbols_annotations(self): @@ -97,10 +97,10 @@ class AstUtilTest(test.TestCase): d = ast_util.keywords_to_dict(keywords) # Make sure we generate a usable dict node by attaching it to a variable and # compiling everything. - output = parser.parse_str('b = 3') - output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) - result, _ = compiler.ast_to_object(output) - self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + node = parser.parse_str('def f(b): pass').body[0] + node.body.append(ast.Return(d)) + result, _ = compiler.ast_to_object(node) + self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'}) def assertMatch(self, target_str, pattern_str): node = parser.parse_expression(target_str) @@ -130,8 +130,8 @@ class AstUtilTest(test.TestCase): 'super(Bar, _).__init__(_)') def _mock_apply_fn(self, target, source): - target, _ = compiler.ast_to_source(target) - source, _ = compiler.ast_to_source(source) + target = compiler.ast_to_source(target) + source = compiler.ast_to_source(source) self._invocation_counts[(target.strip(), source.strip())] += 1 def test_apply_to_single_assignments_dynamic_unpack(self): @@ -157,24 +157,40 @@ class AstUtilTest(test.TestCase): }) def test_parallel_walk(self): - ret = ast.Return( - ast.BinOp( - op=ast.Add(), - left=ast.Name(id='a', ctx=ast.Load()), - right=ast.Num(1))) - node = ast.FunctionDef( - name='f', - args=ast.arguments( - args=[ast.Name(id='a', ctx=ast.Param())], - vararg=None, - kwarg=None, - defaults=[]), - body=[ret], - decorator_list=[], - returns=None) + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) for child_a, child_b in ast_util.parallel_walk(node, node): self.assertEqual(child_a, child_b) + def test_parallel_walk_inconsistent_trees(self): + node_1 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + node_2 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + (a * 2) + """)) + node_3 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 2 + """)) + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_2): + pass + # There is not particular reason to reject trees that differ only in the + # value of a constant. + # TODO(mdan): This should probably be allowed. + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_3): + pass + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py index 25fec7fd532542c53eea4096ea2a13aabfa290db..ba51dcf285036220e01b89e8beeb9aec8ffe36be 100644 --- a/tensorflow/contrib/autograph/pyct/cfg.py +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -67,10 +67,8 @@ class Node(object): if isinstance(self.ast_node, gast.FunctionDef): return 'def %s' % self.ast_node.name elif isinstance(self.ast_node, gast.withitem): - source, _ = compiler.ast_to_source(self.ast_node.context_expr) - return source.strip() - source, _ = compiler.ast_to_source(self.ast_node) - return source.strip() + return compiler.ast_to_source(self.ast_node.context_expr).strip() + return compiler.ast_to_source(self.ast_node).strip() class Graph( diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD index ca1441cf6f8bb034c95b37fcdd9e8158d1db2e39..a0938b3e5f0e52532f63fea6fb4c3e478fc51d93 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD +++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD @@ -24,6 +24,7 @@ py_library( deps = [ "//tensorflow/contrib/autograph/pyct", "@gast_archive//:gast", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py index cc039986c219db1febfe610a5078e26eeb2d5a83..e42f679cfe31f919e10f7baf409247014b3cf386 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Conversion to A-normal form.""" +"""Conversion to A-normal form. + +The general idea of A-normal form is that every intermediate value is +explicitly named with a variable. For more, see +https://en.wikipedia.org/wiki/A-normal_form. + +The specific converters used here are based on Python AST semantics as +documented at https://greentreesnakes.readthedocs.io/en/latest/. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast +import six + +from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer @@ -32,26 +44,375 @@ class DummyGensym(object): # * the symbols generated so far self._idx = 0 - def new_name(self, stem): + def new_name(self, stem='tmp'): self._idx += 1 return stem + '_' + str(1000 + self._idx) class AnfTransformer(transformer.Base): - """Performs the actual conversion.""" + """Performs the conversion to A-normal form (ANF).""" - # TODO(mdan): Link to a reference. - # TODO(mdan): Implement. + # The algorithm is a postorder recursive tree walk. Any given node A may, in + # general, require creation of a series B of Assign statements, which compute + # and explicitly name the intermediate values needed to compute the value of + # A. If A was already a statement, it can be replaced with the sequence B + + # [A]. If A was an expression, B needs to be propagated up the tree until a + # statement is encountered. Since the `ast.NodeTransformer` framework makes + # no provision for subtraversals returning side information, this class + # accumulates the sequence B in an instance variable. - def __init__(self, entity_info): - """Creates a transformer. + # The only other subtlety is that some Python statements (like `if`) have both + # expression fields (`test`) and statement list fields (`body` and `orelse`). + # Any additional assignments needed to name all the intermediate values in the + # `test` can be prepended to the `if` node, but assignments produced by + # processing the `body` and the `orelse` need to be kept together with them, + # and not accidentally lifted out of the `if`. + + def __init__(self, entity_info, gensym_source=None): + """Creates an ANF transformer. Args: entity_info: transformer.EntityInfo + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names """ super(AnfTransformer, self).__init__(entity_info) - self._gensym = DummyGensym(entity_info) + if gensym_source is None: + self._gensym = DummyGensym(entity_info) + else: + self._gensym = gensym_source(entity_info) + self._pending_statements = [] + + def _consume_pending_statements(self): + ans = self._pending_statements + self._pending_statements = [] + return ans + + def _add_pending_statement(self, stmt): + self._pending_statements.append(stmt) + + _trivial_nodes = ( + # Non-nodes that show up as AST fields + bool, six.string_types, + # Leaf nodes that are already in A-normal form + gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes, + gast.NameConstant, gast.Ellipsis, + # Binary operators + gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift, + gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv, + # Unary operators + gast.Invert, gast.Not, gast.UAdd, gast.USub, + # Comparison operators + gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE, + gast.Is, gast.IsNot, gast.In, gast.NotIn, + ) + + def _is_node_trivial(self, node): + if node is None: + return True + elif isinstance(node, self._trivial_nodes): + return True + elif isinstance(node, gast.keyword): + return self._is_node_trivial(node.value) + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._are_children_trivial(node) + return False + + def _are_children_trivial(self, node): + for field in node._fields: + if not field.startswith('__'): + if not self._is_node_trivial(getattr(node, field)): + return False + return True + + def _ensure_node_is_trivial(self, node): + if node is None: + return node + elif isinstance(node, self._trivial_nodes): + return node + elif isinstance(node, list): + # If something's field was actually a list, e.g., variadic arguments. + return [self._ensure_node_is_trivial(n) for n in node] + elif isinstance(node, gast.keyword): + node.value = self._ensure_node_is_trivial(node.value) + return node + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._ensure_fields_trivial(node) + elif isinstance(node, gast.expr): + temp_name = self._gensym.new_name() + temp_assign = templates.replace( + 'temp_name = expr', temp_name=temp_name, expr=node)[0] + self._add_pending_statement(temp_assign) + answer = templates.replace('temp_name', temp_name=temp_name)[0] + return answer + else: + raise ValueError('Do not know how to treat {}'.format(node)) + + def _ensure_fields_trivial(self, node): + for field in node._fields: + if field.startswith('__'): + continue + setattr(node, field, self._ensure_node_is_trivial(getattr(node, field))) + return node + + def _visit_strict_statement(self, node, trivialize_children=True): + assert not self._pending_statements + node = self.generic_visit(node) + if trivialize_children: + self._ensure_fields_trivial(node) + results = self._consume_pending_statements() + results.append(node) + return results + + def _visit_strict_expression(self, node): + node = self.generic_visit(node) + self._ensure_fields_trivial(node) + return node + + # Note on code order: These are listed in the same order as the grammar + # elements on https://github.com/serge-sans-paille/gast + + # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default. + + def visit_Return(self, node): + return self._visit_strict_statement(node) + + def visit_Delete(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Assign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_AugAssign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Print(self, node): + return self._visit_strict_statement(node) + + def visit_For(self, node): + assert not self._pending_statements + # It's important to visit node.iter first, because any statements created + # thereby need to live outside the body. + self.visit(node.iter) + node.iter = self._ensure_node_is_trivial(node.iter) + iter_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.iter, but that is both correct and + # cheap because by this point node.iter is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + iter_stmts.append(node) + return iter_stmts + + def visit_AsyncFor(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncFor nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_While(self, node): + if not self._is_node_trivial(node.test): + msg = ('While with nontrivial test not supported yet ' + '(need to avoid precomputing the test).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_If(self, node): + assert not self._pending_statements + # It's important to visit node.test first, because any statements created + # thereby need to live outside the body. + self.visit(node.test) + node.test = self._ensure_node_is_trivial(node.test) + condition_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.test, but that is both correct and + # cheap because by this point node.test is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + condition_stmts.append(node) + return condition_stmts + + def visit_With(self, node): + assert not self._pending_statements + # It's important to visit node.items first, because any statements created + # thereby need to live outside the body. + for item in node.items: + self.visit(item) + node.items = [self._ensure_node_is_trivial(n) for n in node.items] + contexts_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.items, but that is both correct and + # cheap because by this point node.items is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + contexts_stmts.append(node) + return contexts_stmts + + def visit_AsyncWith(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncWith nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Raise(self, node): + return self._visit_strict_statement(node) + + # Try should be correct by default. + + def visit_Assert(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Assert nodes not supported yet ' + '(need to avoid computing the test when assertions are off, and ' + 'avoid computing the irritant when the assertion does not fire).') + raise ValueError(msg) + return self.generic_visit(node) + + # Import and ImportFrom should be correct by default. + + def visit_Exec(self, node): + return self._visit_strict_statement(node) + + # Global and Nonlocal should be correct by default. + + def visit_Expr(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + # Pass, Break, and Continue should be correct by default. + + def visit_BoolOp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial BoolOp nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_BinOp(self, node): + return self._visit_strict_expression(node) + + def visit_UnaryOp(self, node): + return self._visit_strict_expression(node) + + def visit_Lambda(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Lambda nodes not supported ' + '(cannot insert statements into lambda bodies).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_IfExp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial IfExp nodes not supported yet ' + '(need to convert to If statement, to evaluate branches lazily ' + 'and insert statements into them).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Dict(self, node): + return self._visit_strict_expression(node) + + def visit_Set(self, node): + return self._visit_strict_expression(node) + + def visit_ListComp(self, node): + msg = ('ListComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_SetComp(self, node): + msg = ('SetComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_DictComp(self, node): + msg = ('DictComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_GeneratorExp(self, node): + msg = ('GeneratorExp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_Await(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Await nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Yield(self, node): + return self._visit_strict_expression(node) + + def visit_YieldFrom(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial YieldFrom nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Compare(self, node): + if len(node.ops) > 1: + msg = ('Multi-ary compare nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self._visit_strict_expression(node) + + def visit_Call(self, node): + return self._visit_strict_expression(node) + + def visit_Repr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Repr nodes not supported yet ' + '(need to research their syntax and semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_FormattedValue(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial FormattedValue nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_JoinedStr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial JoinedStr nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Attribute(self, node): + return self._visit_strict_expression(node) + + def visit_Subscript(self, node): + return self._visit_strict_expression(node) + + # Starred and Name are correct by default, because the right thing to do is to + # just recur. + + def visit_List(self, node): + return self._visit_strict_expression(node) + + def visit_Tuple(self, node): + return self._visit_strict_expression(node) + + +def transform(node, entity_info, gensym_source=None): + """Converts the given node to A-normal form (ANF). + + The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form + The specific converters used here are based on Python AST semantics as + documented at https://greentreesnakes.readthedocs.io/en/latest/. -def transform(node, entity_info): - return AnfTransformer(entity_info).visit(node) + Args: + node: The node to transform. + entity_info: transformer.EntityInfo. TODO(mdan): What information does this + argument provide? + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names. + """ + return AnfTransformer(entity_info, gensym_source=gensym_source).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py index 81983a5ecb7b8c6216285409f854e27b7154a08b..951974820c784974cb5bb2320adbb2b07f9332df 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import textwrap + from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer @@ -25,6 +27,22 @@ from tensorflow.contrib.autograph.pyct.common_transformers import anf from tensorflow.python.platform import test +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem='tmp'): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + class AnfTransformerTest(test.TestCase): def _simple_source_info(self): @@ -37,17 +55,349 @@ class AnfTransformerTest(test.TestCase): owner_type=None) def test_basic(self): - def test_function(): a = 0 return a - node, _ = parser.parse_entity(test_function) - node = anf.transform(node, self._simple_source_info()) + node = anf.transform(node.body[0], self._simple_source_info()) result, _ = compiler.ast_to_object(node) - self.assertEqual(test_function(), result.test_function()) + def assert_same_ast(self, expected_node, node, msg=None): + expected_source = compiler.ast_to_source(expected_node, indentation=' ') + expected_str = textwrap.dedent(expected_source).strip() + got_source = compiler.ast_to_source(node, indentation=' ') + got_str = textwrap.dedent(got_source).strip() + self.assertEqual(expected_str, got_str, msg=msg) + + def assert_body_anfs_as_expected(self, expected_fn, test_fn): + # Testing the code bodies only. Wrapping them in functions so the + # syntax highlights nicely, but Python doesn't try to execute the + # statements. + exp_node, _ = parser.parse_entity(expected_fn) + node, _ = parser.parse_entity(test_fn) + node = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + exp_name = exp_node.body[0].name + # Ignoring the function names in the result because they can't be + # the same (because both functions have to exist in the same scope + # at the same time). + node.body[0].name = exp_name + self.assert_same_ast(exp_node, node) + # Check that ANF is idempotent + node_repeated = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + self.assert_same_ast(node_repeated, node) + + def test_binop_basic(self): + + def test_function(x, y, z): + a = x + y + z + return a + + def expected_result(x, y, z): + tmp_1001 = x + y + a = tmp_1001 + z + return a + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_if_basic(self): + + def test_function(a, b, c, e, f, g): + if a + b + c: + d = e + f + g + return d + + def expected_result(a, b, c, e, f, g): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + if tmp_1002: + tmp_1003 = e + f + d = tmp_1003 + g + return d + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_nested_binop_and_return(self): + + def test_function(b, c, d, e): + return (2 * b + c) + (d + e) + + def expected_result(b, c, d, e): + tmp_1001 = 2 * b + tmp_1002 = tmp_1001 + c + tmp_1003 = d + e + tmp_1004 = tmp_1002 + tmp_1003 + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_function_call_and_expr(self): + + def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i): + call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i)) + + def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i): + tmp_1001 = g + h + tmp_1002 = a + b + tmp_1003 = y * z + tmp_1004 = e + f + tmp_1005 = c + d + tmp_1006 = tmp_1001 + i + call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_print(self): + + def test_function(a, b, c): + with a + b + c as d: + print(2 * d + 1) + + def expected_result(a, b, c): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + with tmp_1002 as d: + tmp_1003 = 2 * d + tmp_1004 = tmp_1003 + 1 + print(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_local_definition_and_binary_compare(self): + + def test_function(): + def foo(a, b): + return 2 * a < b + return foo + + def expected_result(): + def foo(a, b): + tmp_1001 = 2 * a + tmp_1002 = tmp_1001 < b + return tmp_1002 + return foo + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_list_literal(self): + + def test_function(a, b, c, d, e, f): + return [a + b, c + d, e + f] + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = [tmp_1001, tmp_1002, tmp_1003] + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_tuple_literal_and_unary(self): + + def test_function(a, b, c, d, e, f): + return (a + b, -(c + d), e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = c + d + tmp_1002 = a + b + tmp_1003 = -tmp_1001 + tmp_1004 = e + f + tmp_1005 = (tmp_1002, tmp_1003, tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_set_literal(self): + + def test_function(a, b, c, d, e, f): + return set(a + b, c + d, e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003) + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_dict_literal_and_repr(self): + + def test_function(foo, bar, baz): + return repr({foo + bar + baz: 7 | 8}) + + def expected_result(foo, bar, baz): + tmp_1001 = foo + bar + tmp_1002 = tmp_1001 + baz + tmp_1003 = 7 | 8 + tmp_1004 = {tmp_1002: tmp_1003} + tmp_1005 = repr(tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_field_read_and_write(self): + + def test_function(a, d): + a.b.c = d.e.f + 3 + + def expected_result(a, d): + tmp_1001 = a.b + tmp_1002 = d.e + tmp_1003 = tmp_1002.f + tmp_1001.c = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_subscript_read_and_write(self): + + def test_function(a, b, c, d, e, f): + a[b][c] = d[e][f] + 3 + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a[b] + tmp_1002 = d[e] + tmp_1003 = tmp_1002[f] + tmp_1001[c] = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_augassign_and_delete(self): + + def test_function(a, x, y, z): + a += x + y + z + del a + del z[y][x] + + def expected_result(a, x, y, z): + tmp_1001 = x + y + a += tmp_1001 + z + del a + tmp_1002 = z[y] + del tmp_1002[x] + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_raise_yield_and_raise(self): + + def test_function(a, c, some_computed, exception): + yield a ** c + raise some_computed('complicated' + exception) + + def expected_result(a, c, some_computed, exception): + tmp_1001 = a ** c + yield tmp_1001 + tmp_1002 = 'complicated' + exception + tmp_1003 = some_computed(tmp_1002) + raise tmp_1003 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_if_with_expressions(self): + + def test_function(foo, bar, function, quux, quozzle, w, x, y, z): + with foo + bar: + function(x + y) + if quux + quozzle: + function(z / w) + + def expected_result(foo, bar, function, quux, quozzle, w, x, y, z): + tmp_1001 = foo + bar + with tmp_1001: + tmp_1002 = x + y + function(tmp_1002) + tmp_1003 = quux + quozzle + if tmp_1003: + tmp_1004 = z / w + function(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_exec(self): + + def test_function(): + # The point is to test A-normal form conversion of exec + # pylint: disable=exec-used + exec('computed' + 5 + 'stuff', globals(), locals()) + + def expected_result(): + # pylint: disable=exec-used + tmp_1001 = 'computed' + 5 + tmp_1002 = tmp_1001 + 'stuff' + tmp_1003 = globals() + tmp_1004 = locals() + exec(tmp_1002, tmp_1003, tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_simple_while_and_assert(self): + + def test_function(foo, quux): + while foo: + assert quux + foo = foo + 1 * 3 + + def expected_result(foo, quux): + while foo: + assert quux + tmp_1001 = 1 * 3 + foo = foo + tmp_1001 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_for(self): + + def test_function(compute, something, complicated, foo): + for foo in compute(something + complicated): + bar = foo + 1 * 3 + return bar + + def expected_result(compute, something, complicated, foo): + tmp_1001 = something + complicated + tmp_1002 = compute(tmp_1001) + for foo in tmp_1002: + tmp_1003 = 1 * 3 + bar = foo + tmp_1003 + return bar + + self.assert_body_anfs_as_expected(expected_result, test_function) + + # This test collects several examples where the definition of A-normal form + # implemented by this transformer is questionable. Mostly it's here to spell + # out what the definition is in these cases. + def test_controversial(self): + + def test_function(b, c, d, f): + a = c + d + a.b = c + d + a[b] = c + d + a += c + d + a, b = c + a, b = c, d + a = f(c) + a = f(c + d) + a[b + d] = f.e(c + d) + + def expected_result(b, c, d, f): + a = c + d + a.b = c + d # Should be a.b = tmp? (Definitely not tmp = c + d) + a[b] = c + d # Should be a[b] = tmp? (Definitely not tmp = c + d) + a += c + d # Should be a += tmp? (Definitely not tmp = c + d) + a, b = c # Should be a = c[0], b = c[1]? Or not? + a, b = c, d # Should be a = c, b = d? Or not? + a = f(c) + tmp_1001 = c + d + a = f(tmp_1001) + tmp_1002 = b + d + tmp_1003 = f.e + tmp_1004 = c + d + a[tmp_1002] = tmp_1003(tmp_1004) # Or should be a[tmp1] = tmp2? + + self.assert_body_anfs_as_expected(expected_result, test_function) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py index c90a5e89c2d673a0bcfb4609d19c842cf7cc135d..f9cee109624dafd4da4a0981c5f8fda0a5d8a5e7 100644 --- a/tensorflow/contrib/autograph/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -30,44 +30,7 @@ import tempfile import astor import gast -from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import origin_info -from tensorflow.contrib.autograph.pyct import parser - - -def _build_source_map(node, code): - """Return the Python objects represented by given AST. - - Compiling the AST code this way ensures that the source code is readable by - e.g. `pdb` or `inspect`. - - Args: - node: An AST node of the original generated code, before the source code is - generated. - code: The string representation of the source code for the newly generated - code. - - Returns: - Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph - generated code. - """ - # After we have the final generated code we reparse it to get the final line - # numbers. Then we walk through the generated and original ASTs in parallel - # to build the mapping between the user and generated code. - new_node = parser.parse_str(code) - origin_info.resolve(new_node, code) - source_mapping = {} - for before, after in ast_util.parallel_walk(node, new_node): - # Need both checks because if origin information is ever copied over to new - # nodes then we need to rely on the fact that only the original user code - # has the origin annotation. - if (anno.hasanno(before, anno.Basic.ORIGIN) and - anno.hasanno(after, anno.Basic.ORIGIN)): - source_info = anno.getanno(before, anno.Basic.ORIGIN) - new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number - source_mapping[new_line_number] = source_info - return source_mapping def ast_to_source(node, indentation=' '): @@ -81,24 +44,28 @@ def ast_to_source(node, indentation=' '): code: The source code generated from the AST object source_mapping: A mapping between the user and AutoGraph generated code. """ - original_node = node - if isinstance(node, gast.AST): - node = gast.gast_to_ast(node) + if not isinstance(node, (list, tuple)): + node = (node,) generator = astor.codegen.SourceGenerator(indentation, False, astor.string_repr.pretty_string) - generator.visit(node) - generator.result.append('\n') + + for n in node: + if isinstance(n, gast.AST): + n = gast.gast_to_ast(n) + generator.visit(n) + generator.result.append('\n') + # In some versions of Python, literals may appear as actual values. This # ensures everything is string. code = map(str, generator.result) code = astor.source_repr.pretty_source(code).lstrip() - source_mapping = _build_source_map(original_node, code) - return code, source_mapping + return code -def ast_to_object(node, +def ast_to_object(nodes, indentation=' ', + include_source_map=False, source_prefix=None, delete_on_exit=True): """Return the Python objects represented by given AST. @@ -107,42 +74,46 @@ def ast_to_object(node, e.g. `pdb` or `inspect`. Args: - node: The code to compile, as an AST object. - indentation: The string to use for indentation. - source_prefix: Optional string to print as-is into the source file. - delete_on_exit: Whether to delete the temporary file used for compilation on - exit. + nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST + object. + indentation: Text, the string to use for indentation. + include_source_map: bool, whether to attach a source map to the compiled + object. Also see origin_info.py. + source_prefix: Optional[Text], string to print as-is into the source file. + delete_on_exit: bool, whether to delete the temporary file used for + compilation on exit. Returns: - compiled_node: A module object containing the compiled source code. + compiled_nodes: A module object containing the compiled source code. source: The source code of the compiled object Raises: ValueError: If ag_source_map__ is already in the namespace of the compiled - node. + nodes. """ - # code_source_mapping does not yet include the offsets from import statements. - source, code_source_mapping = ast_to_source(node, indentation=indentation) + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + + source = ast_to_source(nodes, indentation=indentation) + + if source_prefix: + source = source_prefix + '\n' + source with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: - # TODO(znado): move into an _offset_source_map() helper function. - # Need to offset the generated line numbers by the number of import lines. - if source_prefix: - num_import_lines = source_prefix.count('\n') + 1 - else: - num_import_lines = 0 - source_mapping = {} - for line_number, original_position in code_source_mapping.items(): - source_map_key = origin_info.CodeLocation( - file_path=f.name, line_number=line_number + num_import_lines) - source_mapping[source_map_key] = original_position module_name = os.path.basename(f.name[:-3]) - if source_prefix: - f.write(source_prefix) - f.write('\n') f.write(source) + + if isinstance(nodes, (list, tuple)): + indices = range(-len(nodes), 0) + else: + indices = (-1,) + + if include_source_map: + source_map = origin_info.source_map(nodes, source, f.name, indices) + + # TODO(mdan): Try flush() and delete=False instead. if delete_on_exit: atexit.register(lambda: os.remove(f.name)) - compiled_node = imp.load_source(module_name, f.name) + compiled_nodes = imp.load_source(module_name, f.name) # TODO(znado): Clean this up so we don't need to attach it to the namespace. # TODO(znado): This does not work for classes because their methods share a @@ -158,11 +129,13 @@ def ast_to_object(node, # is hard, and this cleanly fixes the # issues encountered with nested functions because this is attached to the # outermost one. - source_map_name = 'ag_source_map__' - if source_map_name in compiled_node.__dict__: - raise ValueError('cannot convert %s because is has namespace attribute ' - '"%s", which is reserved for AutoGraph.' % - (compiled_node, source_map_name)) - compiled_node.__dict__[source_map_name] = source_mapping - - return compiled_node, source + if include_source_map: + # TODO(mdan): This name should be decided by the caller. + source_map_name = 'ag_source_map__' + if source_map_name in compiled_nodes.__dict__: + raise ValueError('cannot convert %s because is has namespace attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_nodes, source_map_name)) + compiled_nodes.__dict__[source_map_name] = source_map + + return compiled_nodes, source diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py index e29fa9324c6b742e8f60c1b6b3302c6fa1374a96..cf783da6a3e540c6901a5fe9a5e4afdb6b1cfc03 100644 --- a/tensorflow/contrib/autograph/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -59,7 +59,7 @@ class CompilerTest(test.TestCase): value=gast.Str('c')) ]) - source, _ = compiler.ast_to_source(node, indentation=' ') + source = compiler.ast_to_source(node, indentation=' ') self.assertEqual( textwrap.dedent(""" if 1: diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py index 614e346634ddc180ee2364407744537d725eb325..b60651a30e342dabe40cbcef1486826e16c2e2c7 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info.py +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -18,53 +18,122 @@ from __future__ import division from __future__ import print_function import collections +import tokenize import gast +import six from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.util import tf_inspect -class CodeLocation( - collections.namedtuple('CodeLocation', ('file_path', 'line_number'))): - """Location of a line of code. +class LineLocation( + collections.namedtuple('LineLocation', ('filename', 'lineno'))): + """Similar to Location, but without column information. Attributes: - file_path: text, the full path to the file containing the code. - line_number: Int, the 1-based line number of the code in its file. + filename: Text + lineno: int, 1-based """ pass +class Location( + collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))): + """Encodes code location information. + + Attributes: + filename: Text + lineno: int, 1-based + col_offset: int + """ + + @property + def line_loc(self): + return LineLocation(self.filename, self.lineno) + + class OriginInfo( - collections.namedtuple('OriginInfo', - ('file_path', 'function_name', 'line_number', - 'column_offset', 'source_code_line'))): + collections.namedtuple( + 'OriginInfo', + ('loc', 'function_name', 'source_code_line', 'comment'))): """Container for information about the source code before conversion. - Instances of this class contain information about the source code that - transformed code originated from. Examples include: - * line number - * file name - * original user code + Attributes: + loc: Location + function_name: Optional[Text] + source_code_line: Text + comment: Optional[Text] """ def as_frame(self): - """Makes a traceback frame tuple. - - Returns: - A tuple of (file_path, line_number, function_name, source_code_line). - """ - return (self.file_path, self.line_number, self.function_name, + """Returns a 4-tuple consistent with the return of traceback.extract_tb.""" + return (self.loc.filename, self.loc.lineno, self.function_name, self.source_code_line) +# TODO(mdan): This source map should be a class - easier to refer to. +def source_map(nodes, code, filename, indices_in_code): + """Creates a source map between an annotated AST and the code it compiles to. + + Args: + nodes: Iterable[ast.AST, ...] + code: Text + filename: Optional[Text] + indices_in_code: Union[int, Iterable[int, ...]], the positions at which + nodes appear in code. The parser always returns a module when parsing + code. This argument indicates the position in that module's body at + which the corresponding of node should appear. + + Returns: + Dict[CodeLocation, OriginInfo], mapping locations in code to locations + indicated by origin annotations in node. + """ + reparsed_nodes = parser.parse_str(code) + reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code] + + resolve(reparsed_nodes, code) + result = {} + + for before, after in ast_util.parallel_walk(nodes, reparsed_nodes): + # Note: generated code might not be mapped back to its origin. + # TODO(mdan): Generated code should always be mapped to something. + origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None) + final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None) + if origin_info is None or final_info is None: + continue + + line_loc = LineLocation(filename, final_info.loc.lineno) + + existing_origin = result.get(line_loc) + if existing_origin is not None: + # Overlaps may exist because of child nodes, but almost never to + # different line locations. Exception make decorated functions, where + # both lines are mapped to the same line in the AST. + + # Line overlaps: keep bottom node. + if existing_origin.loc.line_loc == origin_info.loc.line_loc: + if existing_origin.loc.lineno >= origin_info.loc.lineno: + continue + + # In case of overlaps, keep the leftmost node. + if existing_origin.loc.col_offset <= origin_info.loc.col_offset: + continue + + result[line_loc] = origin_info + + return result + + # TODO(znado): Consider refactoring this into a Visitor. -def resolve(node, source, function=None): +# TODO(mdan): Does this work correctly with inner functions? +def resolve(nodes, source, function=None): """Adds an origin information to all nodes inside the body of function. Args: - node: The AST node for the function whose body nodes will be annotated. + nodes: Union[ast.AST, Iterable[ast.AST, ...]] source: Text, the source code string for the function whose body nodes will be annotated. function: Callable, the function that will have all nodes inside of it @@ -76,25 +145,42 @@ def resolve(node, source, function=None): A tuple of the AST node for function and a String containing its source code. """ + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + if function: _, function_lineno = tf_inspect.getsourcelines(function) function_filepath = tf_inspect.getsourcefile(function) else: function_lineno = None function_filepath = None + + # TODO(mdan): Pull this to a separate utility. + code_reader = six.StringIO(source) + comment_map = {} + for token in tokenize.generate_tokens(code_reader.readline): + tok_type, tok_string, loc, _, _ = token + srow, _ = loc + if tok_type == tokenize.COMMENT: + comment_map[srow] = tok_string.strip()[1:].strip() + source_lines = source.split('\n') - for n in gast.walk(node): - if hasattr(n, 'lineno'): - # n.lineno is relative to the start of the enclosing function, so need to - # offset it by the line of the function. - source_code_line = source_lines[n.lineno - 1] + for node in nodes: + for n in gast.walk(node): + if not hasattr(n, 'lineno'): + continue + + lineno_in_body = n.lineno + + source_code_line = source_lines[lineno_in_body - 1] if function: - source_lineno = n.lineno + function_lineno - 1 + source_lineno = function_lineno + lineno_in_body function_name = function.__name__ else: - source_lineno = n.lineno + source_lineno = lineno_in_body function_name = None - anno.setanno( - n, anno.Basic.ORIGIN, - OriginInfo(function_filepath, function_name, source_lineno, - n.col_offset, source_code_line)) + + location = Location(function_filepath, source_lineno, n.col_offset) + origin = OriginInfo(location, function_name, + source_code_line, comment_map.get(source_lineno)) + anno.setanno(n, anno.Basic.ORIGIN, origin) diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaa13007ea0ae331293c216a76352956c0ee9ec --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py @@ -0,0 +1,104 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for origin_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class OriginInfoTest(test.TestCase): + + def test_source_map(self): + + def test_fn(x): + if x > 0: + x += 1 + return x + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + # Insert a traced line. + new_node = parser.parse_str('x = abs(x)').body[0] + anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN) + fn_node.body.insert(0, new_node) + + # Insert an untraced line. + fn_node.body.insert(0, parser.parse_str('x = 0').body[0]) + + modified_source = compiler.ast_to_source(fn_node) + + source_map = origin_info.source_map(fn_node, modified_source, + 'test_filename', [0]) + + loc = origin_info.LineLocation('test_filename', 1) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertEqual(origin.loc.lineno, 1) + + # The untraced line, inserted second. + loc = origin_info.LineLocation('test_filename', 2) + self.assertFalse(loc in source_map) + + # The traced line, inserted first. + loc = origin_info.LineLocation('test_filename', 3) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + loc = origin_info.LineLocation('test_filename', 4) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + def test_resolve(self): + + def test_fn(x): + """Docstring.""" + return x # comment + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + origin = anno.getanno(fn_node, anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 1) + self.assertEqual(origin.loc.col_offset, 0) + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertIsNone(origin.comment) + + origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 2) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' """Docstring."""') + self.assertIsNone(origin.comment) + + origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 3) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' return x # comment') + self.assertEqual(origin.comment, 'comment') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py index c961efa892df6a21804dae8f52ef64bf99cd409e..112ed46a1e487a7904e79267c1ce7db0ad914552 100644 --- a/tensorflow/contrib/autograph/pyct/parser.py +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -37,6 +37,7 @@ def parse_entity(entity): def parse_str(src): """Returns the AST of given piece of code.""" + # TODO(mdan): This should exclude the module things are autowrapped in. return gast.parse(src) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py index 9a84f1231cb71745f778285f30ada151a7c1accd..7f2b379d3de236020f1ec2b8a4972cc67b10b060 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py @@ -39,7 +39,7 @@ from tensorflow.contrib.autograph.pyct.static_analysis import annos class Definition(object): """Definition objects describe a unique definition of a variable. - Subclasses of this may be used by passing an appropriate factory fuction to + Subclasses of this may be used by passing an appropriate factory function to resolve. Attributes: diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index 72d1d3b269f1293fa1e90a40d46e9abf8720bf54..5831d57ceb58d4b291a4f52bbf4282e107104219 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -140,8 +140,8 @@ class ReplaceTransformer(gast.NodeTransformer): def _set_inner_child_context(self, node, ctx): if isinstance(node, gast.Attribute): - self._set_inner_child_context(node.value, ctx) - node.ctx = gast.Load() + self._set_inner_child_context(node.value, gast.Load()) + node.ctx = ctx elif isinstance(node, gast.Tuple): for e in node.elts: self._set_inner_child_context(e, ctx) diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py index a8bbc5a4de3bacce10dd0b8b0648bd0cb495e8c9..77e8ff62fd8665e095cfb410a2aa418e9f9bd52b 100644 --- a/tensorflow/contrib/autograph/pyct/templates_test.py +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -97,6 +97,19 @@ class TemplatesTest(test.TestCase): with self.assertRaises(ValueError): templates.replace(template, foo=1) + def test_replace_attribute_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace( + template, + foo=parser.parse_expression('a.b.c'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load) + self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load) + def test_replace_call_keyword(self): template = """ def test_fn(): diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..957db356f7e1acf673ce5db7c8087208af43ac23 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/BUILD @@ -0,0 +1,43 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "testing", + srcs = [ + "codegen.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + ], +) + +py_test( + name = "codegen_test", + size = "large", + srcs = ["codegen_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":testing", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/contrib/autograph/pyct/testing/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..279e7c09dc6449184e2029ad65fc3f71d94db8b4 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen.py @@ -0,0 +1,234 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Random code generation for testing/fuzzing.""" +# pylint: disable=invalid-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import string + +import gast +import numpy as np + +from tensorflow.contrib.autograph.pyct import templates + + +class NodeSampler(object): + sample_map = None + + def sample(self): + nodes, magnitudes = zip(*self.sample_map.items()) + return np.random.choice( + nodes, p=np.array(magnitudes, dtype='float32') / np.sum(magnitudes)) + + +class StatementSampler(NodeSampler): + sample_map = dict(( + (gast.Assign, 10), + (gast.Print, 1), + (gast.If, 2), + (gast.While, 2), + (gast.For, 0), + )) + + +class ExpressionSampler(NodeSampler): + sample_map = dict(( + (gast.UnaryOp, 1), + (gast.BinOp, 8), + (gast.Name, 1), + (gast.Call, 0), + )) + + +class CompareSampler(NodeSampler): + sample_map = dict(( + (gast.Eq, 1), + (gast.NotEq, 1), + (gast.Lt, 1), + (gast.LtE, 1), + (gast.Gt, 1), + (gast.GtE, 1), + (gast.Is, 1), + (gast.IsNot, 1), + )) + + +class BinaryOpSampler(NodeSampler): + sample_map = dict(( + (gast.Add, 1), + (gast.Sub, 1), + (gast.Mult, 1), + (gast.Div, 1), + (gast.FloorDiv, 1), + (gast.Mod, 1), + (gast.Pow, 1), + )) + + +class UnaryOpSampler(NodeSampler): + sample_map = dict(((gast.USub, 1), (gast.UAdd, 0))) + + +class NameSampler(NodeSampler): + sample_map = dict(( + ('new', 1), + ('existing', 1), + )) + + +N_CONTROLFLOW_STATEMENTS = 10 +N_FUNCTIONDEF_STATEMENTS = 10 + + +class CodeGenerator(object): + """Generate random syntactically-valid Python ASTs.""" + + def __init__(self, max_depth=3, depth=0): + self.max_depth = max_depth + self.depth = depth + + def generate_statement(self): + """Generate a statement node, dispatching to the correct class method.""" + desired_node = StatementSampler().sample() + self.depth += 1 + + # Enforce some constraints on generating statements. + # E.g., if statements need at least 3 readable variables. + # If we fail to satisfy our constraints, draw another sample. + if desired_node in (gast.While, gast.For, gast.If): + if self.depth > self.max_depth: + return self.generate_statement() + + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + visitor = getattr(self, method) + node = visitor() + self.depth -= 1 + return node + + def sample_node_list(self, low, high, generator): + """Generate a list of statements of random length. + + Args: + low: Fewest number of statements to generate. + high: Highest number of statements to generate. + generator: Function to call to generate nodes. + + Returns: + A list of statements. + """ + statements = [] + for _ in range(np.random.randint(low, high)): + statements.append(generator()) + return statements + + def generate_Name(self, ctx=gast.Load()): + variable_name = '_' + ''.join( + random.choice(string.ascii_lowercase) for _ in range(4)) + return gast.Name(variable_name, ctx=ctx, annotation=None) + + def generate_BinOp(self): + # TODO(alexbw): convert to generate_expression when we get to limit + # expression depth. + op = BinaryOpSampler().sample()() + return gast.BinOp(self.generate_Name(), op, self.generate_Name()) + + def generate_Compare(self): + op = CompareSampler().sample()() + return gast.Compare(self.generate_Name(), [op], [self.generate_Name()]) + + def generate_UnaryOp(self): + operand = self.generate_Name() + op = UnaryOpSampler().sample()() + return gast.UnaryOp(op, operand) + + def generate_expression(self): + desired_node = ExpressionSampler().sample() + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + generator = getattr(self, method) + return generator() + + def generate_Assign(self): + """Generate an Assign node.""" + # Generate left-hand side + target_node = self.generate_Name(gast.Store()) + # Generate right-hand side + value_node = self.generate_expression() + # Put it all together + node = gast.Assign(targets=[target_node], value=value_node) + return node + + def generate_If(self): + """Generate an If node.""" + test = self.generate_Compare() + + # Generate true branch statements + body = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + # Generate false branch statements + orelse = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + node = gast.If(test, body, orelse) + return node + + def generate_While(self): + """Generate a While node.""" + + test = self.generate_Compare() + body = self.sample_node_list( + low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement) + orelse = [] # not generating else statements + + node = gast.While(test, body, orelse) + return node + + def generate_Call(self): + raise NotImplementedError + + def generate_Return(self): + return gast.Return(self.generate_expression()) + + def generate_Print(self): + return templates.replace('print(x)', x=self.generate_expression())[0] + + def generate_FunctionDef(self): + """Generate a FunctionDef node.""" + + # Generate the arguments, register them as available + arg_vars = self.sample_node_list( + low=2, high=10, generator=lambda: self.generate_Name(gast.Param())) + args = gast.arguments(arg_vars, None, [], [], None, []) + + # Generate the function body + body = self.sample_node_list( + low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement) + body.append(self.generate_Return()) + fn_name = self.generate_Name().id + node = gast.FunctionDef(fn_name, args, body, (), None) + return node + + +def generate_random_functiondef(): + return CodeGenerator().generate_FunctionDef() diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py new file mode 100644 index 0000000000000000000000000000000000000000..255c3b2a2edc65ab978d8c32682fafd8ce00f5ac --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for type_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct.testing import codegen +from tensorflow.python.platform import test + + +class CodeGenTest(test.TestCase): + + def test_codegen_gens(self): + np.random.seed(0) + for _ in range(1000): + node = codegen.generate_random_functiondef() + fn = compiler.ast_to_object(node) + self.assertIsNotNone( + fn, 'Generated invalid AST that could not convert to source.') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 3e8906823ee63f032e1de00fd661212772ccd917..969ca12244148b346ba3160fba124384a9641a05 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -421,10 +421,13 @@ class Base(gast.NodeTransformer): source_file = self.entity_info.source_file did_enter_function = False local_scope_size_at_entry = len(self._local_scope_state) + processing_expr_node = False try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): did_enter_function = True + elif isinstance(node, gast.Expr): + processing_expr_node = True if did_enter_function: self._enclosing_entities.append(node) @@ -433,9 +436,23 @@ class Base(gast.NodeTransformer): self._lineno = node.lineno self._col_offset = node.col_offset + if processing_expr_node: + entry_expr_value = node.value + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): result = super(Base, self).visit(node) + # Adjust for consistency: replacing the value of an Expr with + # an Assign node removes the need for the Expr node. + if processing_expr_node: + if isinstance(result, gast.Expr) and result.value != entry_expr_value: + # When the replacement is a list, it is assumed that the list came + # from a template that contained a number of statements, which + # themselves are standalone and don't require an enclosing Expr. + if isinstance(result.value, + (list, tuple, gast.Assign, gast.AugAssign)): + result = result.value + # On exception, the local scope integrity is not guaranteed. if did_enter_function: self._enclosing_entities.pop() diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 71079cfdc04feaf26ab07b7dba193f745555433f..ccbe5fc9541dfad561d8eab730e2b15f6250ceb2 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -27,6 +27,7 @@ from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops @@ -50,7 +51,9 @@ def dynamic_builtin(f, *args, **kwargs): def dynamic_len(list_or_tensor): """Implementation of len using dynamic dispatch.""" - if tensor_util.is_tensor(list_or_tensor): + if _is_tensor_list(list_or_tensor): + return list_ops.tensor_list_length(list_or_tensor) + elif tensor_util.is_tensor(list_or_tensor): shape = list_or_tensor.shape if not shape.ndims: raise ValueError( @@ -59,6 +62,11 @@ def dynamic_len(list_or_tensor): return len(list_or_tensor) +def _is_tensor_list(list_or_tensor): + return (tensor_util.is_tensor(list_or_tensor) + and list_or_tensor.dtype == dtypes.variant) + + def dynamic_int(num_or_tensor, **kwargs): """Implementation of int() using dynamic dispatch.""" if tensor_util.is_tensor(num_or_tensor): diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index d7c71a20ed4ba6a55dc0356ab5a3d096ed042e59..88a3909de4f34c11ac7ac3f0a865d76b675d0d06 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -324,7 +324,7 @@ If you encounter a log line that includes the following: "filename":"/usr/share/grpc/roots.pem" ``` -you likely need to copy the [gRPC roots.pem file][grpcPem] to +you likely need to copy the [gRPC `roots.pem` file][grpcPem] to `/usr/share/grpc/roots.pem` on your local machine. [grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem @@ -338,7 +338,10 @@ are available. - **Compute Engine**: When running on Compute Engine, the client will often use the service account from the virtual machine's metadata service. Be sure to authorize your Compute Engine VM to have access to the Cloud Bigtable service - when creating your VM. + when creating your VM, or [update the VM's scopes][update-vm-scopes] on a + running VM if you run into this issue. - **Cloud TPU**: Your Cloud TPUs run with the designated Cloud TPU service account dedicated to your GCP project. Ensure the service account has been authorized via the Cloud Console to access your Cloud Bigtable instances. + +[update-vm-scopes]: https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index fd30aa8bbb962257c1ef5ac07e047fffca88c4bc..1102fb3c2dfb9ed71d286dc860a17b2079381eb0 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Python API for TensorFlow's Bigtable integration. +"""The Python API for TensorFlow's Cloud Bigtable integration. TensorFlow has support for reading from and writing to Cloud Bigtable. To use -the Bigtable TensorFlow integration, first create a BigtableClient (which -configures your connection to Cloud Bigtable), and then open a Table. The Table -object then allows you to create numerous @{tf.data.Dataset}s to read data, or -write a @{tf.data.Dataset} object to the underlying Bigtable Table. +TensorFlow + Cloud Bigtable integration, first create a BigtableClient to +configure your connection to Cloud Bigtable, and then create a BigtableTable +object to allow you to create numerous @{tf.data.Dataset}s to read data, or +write a @{tf.data.Dataset} object to the underlying Cloud Bigtable table. -For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable. +For background on Cloud Bigtable, see: https://cloud.google.com/bigtable . """ from __future__ import absolute_import @@ -48,7 +48,7 @@ class BigtableClient(object): """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF. BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the - `table` method to open a Bigtable Table. + `table` method to open a Bigtable table. """ def __init__(self, @@ -94,7 +94,7 @@ class BigtableClient(object): project_id, instance_id, connection_pool_size, max_receive_message_size) def table(self, name, snapshot=None): - """Opens a table and returns a `BigtableTable` object. + """Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object. Args: name: A `tf.string` `tf.Tensor` name of the table to open. @@ -102,8 +102,8 @@ class BigtableClient(object): request the creation of a snapshot. (Note: currently unimplemented.) Returns: - A `BigtableTable` python object representing the operations available on - the table. + A `tf.contrib.bigtable.BigtableTable` Python object representing the + operations available on the table. """ # TODO(saeta): Implement snapshot functionality. table = gen_bigtable_ops.bigtable_table(self._resource, name) @@ -133,7 +133,8 @@ class BigtableTable(object): """Retrieves the values of columns for a dataset of keys. Example usage: - ``` + + ```python table = bigtable_client.table("my_table") key_dataset = table.get_keys_prefix("imagenet") images = key_dataset.apply(table.lookup_columns(("cf1", "image"), @@ -144,7 +145,8 @@ class BigtableTable(object): Alternatively, you can use keyword arguments to specify the columns to capture. Example (same as above, rewritten): - ``` + + ```python table = bigtable_client.table("my_table") key_dataset = table.get_keys_prefix("imagenet") images = key_dataset.apply(table.lookup_columns( @@ -152,15 +154,17 @@ class BigtableTable(object): training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) ``` - Note: certain kwargs keys are reserved, and thus some column families cannot - be identified using the kwargs syntax. Instead, please use the args syntax. - This list includes: + Note: certain `kwargs` keys are reserved, and thus, some column families + cannot be identified using the `kwargs` syntax. Instead, please use the + `args` syntax. This list includes: + - 'name' - This list can change at any time. + + Note: this list can change at any time. Args: *args: A list of tuples containing (column family, column name) pairs. - **kwargs: Column families and + **kwargs: Column families (keys) and column qualifiers (values). Returns: A function that can be passed to `tf.data.Dataset.apply` to retrieve the @@ -331,7 +335,7 @@ class BigtableTable(object): """Retrieves row (including values) from the Bigtable service at high speed. Rows with row-key prefixed by `prefix` will be retrieved. This method is - similar to `scan_prefix`, but by constrast performs multiple sub-scans in + similar to `scan_prefix`, but by contrast performs multiple sub-scans in parallel in order to achieve higher performance. Note: The dataset produced by this method is not deterministic! @@ -390,7 +394,7 @@ class BigtableTable(object): """Retrieves rows (including values) from the Bigtable service. Rows with row-keys between `start` and `end` will be retrieved. This method - is similar to `scan_range`, but by constrast performs multiple sub-scans in + is similar to `scan_range`, but by contrast performs multiple sub-scans in parallel in order to achieve higher performance. Note: The dataset produced by this method is not deterministic! @@ -712,7 +716,7 @@ class _BigtableScanDataset(dataset_ops.Dataset): class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset): - """_BigtableKeyRangeDataset returns key pairs from the Bigtable. + """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table. """ def __init__(self, table, prefix, start, end): diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index ef0e80cd0997bc0e95cd0d150e87db144a2dde44..5fcb19a47aac492d49b0d8e99af5699bae2ad9f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -147,6 +147,7 @@ py_library( deps = [ ":distillation_loss", ":estimator_utils", + ":model", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", @@ -190,7 +191,7 @@ py_test( py_test( name = "estimator_test", - size = "medium", + size = "large", srcs = ["estimator_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 7eb429b636a5193a124dd9b0c020dae6cac910cb..194a5c8754cb0ab2db299e3fb5c998c0f27f8435 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -26,6 +26,7 @@ from __future__ import print_function import six from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks @@ -34,6 +35,7 @@ from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batc from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops @@ -62,27 +64,30 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn(features, - labels, - mode, - head, - dnn_hidden_units, - dnn_feature_columns, - tree_learner_config, - num_trees, - tree_examples_per_layer, - config=None, - dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, - dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, - dnn_steps_to_train=10000, - predict_with_tree_only=False, - tree_feature_columns=None, - tree_center_bias=False, - dnn_to_tree_distillation_param=None, - use_core_versions=False): +def _dnn_tree_combined_model_fn( + features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None, + use_core_versions=False, + output_type=model.ModelBuilderOutputType.MODEL_FN_OPS, + override_global_step_value=None): """DNN and GBDT combined model_fn. Args: @@ -131,6 +136,12 @@ def _dnn_tree_combined_model_fn(features, will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. Returns: A `ModelFnOps` object. @@ -156,6 +167,10 @@ def _dnn_tree_combined_model_fn(features, partitioned_variables.min_max_variable_partitioner( max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) + if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and + not use_core_versions): + raise ValueError("You must use core versions with Estimator Spec") + with variable_scope.variable_scope( dnn_parent_scope, values=tuple(six.itervalues(features)), @@ -235,7 +250,8 @@ def _dnn_tree_combined_model_fn(features, learner_config=tree_learner_config, feature_columns=tree_feature_columns, logits_dimension=head.logits_dimension, - features=tree_features) + features=tree_features, + use_core_columns=use_core_versions) with ops.name_scope("gbdt"): predictions_dict = gbdt_model.predict(mode) @@ -284,63 +300,98 @@ def _dnn_tree_combined_model_fn(features, del loss return control_flow_ops.no_op() - if use_core_versions: - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits) - dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - dnn_train_op).train_op + if tree_center_bias: + num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - tree_train_op = head.create_estimator_spec( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits) - tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - tree_train_op).train_op + if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + dnn_train_op).train_op - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + tree_train_op).train_op + + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op + + # Add the hooks + model_fn_ops.training_hooks.extend([ + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ]) + return model_fn_ops + + elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC: + fusion_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_no_train_op_fn, logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( + dnn_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( + logits=dnn_logits) + tree_spec = head.create_estimator_spec( features=tree_features, mode=mode, labels=labels, train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op - - if tree_center_bias: - num_trees += 1 - finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - - model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, - tree_train_op), - trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) - ]) + logits=tree_train_logits) - return model_fn_ops + training_hooks = [ + trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train, + tree_spec.train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ] + fusion_spec = fusion_spec._replace(training_hooks=training_hooks + + list(fusion_spec.training_hooks)) + return fusion_spec class DNNBoostedTreeCombinedClassifier(estimator.Estimator): @@ -369,7 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -425,6 +477,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -455,7 +511,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, @@ -489,7 +546,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -545,6 +603,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -580,7 +642,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, @@ -615,7 +678,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -666,6 +730,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ def _model_fn(features, labels, mode, config): @@ -690,10 +758,109 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator): + """Initializes a core version of DNNBoostedTreeCombinedEstimator. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + head: `Head` instance. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. + """ + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + head, + model_dir=None, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None): + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC, + use_core_versions=True, + override_global_step_value=None) + + super(CoreDNNBoostedTreeCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 9b7acfa664b0398216b5a7fb904960d8363929d6..839eedd3a87ccaa1faecd1966fe5907d682cac02 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -28,10 +28,11 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest - +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): features = { @@ -156,5 +157,72 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_eval_input_fn, steps=1) +class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[core_feature_column.numeric_column("x")]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished + self._assert_checkpoint(est.model_dir, global_step=14) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + + def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 38fa8c38345f5006628b3b944d0c89d2df54f998..870ce2442bb5e98db7615c43054c9c827b8c88f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -22,8 +22,16 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator.canned import head as core_head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops +from tensorflow.python.ops.losses import losses as core_losses + + +# ================== Old estimator interface=================================== +# The estimators below were designed for old feature columns and old estimator +# interface. They can be used with new feature columns and losses by setting +# use_core_libs = True. class GradientBoostedDecisionTreeClassifier(estimator.Estimator): @@ -43,7 +51,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -77,6 +86,14 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. Raises: ValueError: If learner_config is not valid. @@ -117,6 +134,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -140,7 +158,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -174,6 +193,14 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -197,6 +224,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -222,7 +250,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -252,6 +281,14 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -266,6 +303,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -275,24 +313,23 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): class GradientBoostedDecisionTreeRanker(estimator.Estimator): """A ranking estimator using gradient boosted decision trees.""" - def __init__( - self, - learner_config, - examples_per_layer, - head, - ranking_model_pair_keys, - num_trees=None, - feature_columns=None, - weight_column_name=None, - model_dir=None, - config=None, - label_keys=None, - feature_engineering_fn=None, - logits_modifier_function=None, - center_bias=False, - use_core_libs=False, - output_leaf_index=False, - ): + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=False, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeRanker instance. This is an estimator that can be trained off the pairwise data and can be @@ -332,7 +369,14 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree - + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. Raises: ValueError: If learner_config is not valid. """ @@ -351,14 +395,41 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) +# ================== New Estimator interface=================================== +# The estimators below use new core Estimator interface and must be used with +# new feature columns and heads. + +# For multiclass classification, use the following head since it uses loss +# that is twice differentiable. +def core_multiclass_head(n_classes): + """Core head for multiclass problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_maxent_loss( + labels=labels, logits=logits, weights=None, num_classes=n_classes) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=n_classes, + loss_fn=loss_fn, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + # pylint:enable=protected-access + + return head_fn + class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): - """An estimator using gradient boosted decision trees.""" + """An estimator using gradient boosted decision trees. + + Useful for training with user specified `Head`. + """ def __init__(self, learner_config, @@ -374,6 +445,36 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): logits_modifier_function=None, center_bias=True, output_leaf_index=False): + """Initializes a core version of GradientBoostedDecisionTreeEstimator. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + """ def _model_fn(features, labels, mode, config): return model.model_builder( @@ -392,8 +493,92 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': True, 'output_leaf_index': output_leaf_index, + 'override_global_step_value': None }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) super(CoreGradientBoostedDecisionTreeEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): + """A ranking estimator using gradient boosted decision trees.""" + + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + logits_modifier_function=None, + center_bias=False, + output_leaf_index=False): + """Initializes a GradientBoostedDecisionTreeRanker instance. + + This is an estimator that can be trained off the pairwise data and can be + used for inference on non-paired data. This is essentially LambdaMart. + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + ranking_model_pair_keys: Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + + Raises: + ValueError: If learner_config is not valid. + """ + + def _model_fn(features, labels, mode, config): + return model.ranking_model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': head, + 'n_classes': 2, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': True, + 'output_leaf_index': output_leaf_index, + 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': None + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeRanker, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index f787d3cdb81febded62e472549ec98250a0393ff..68d710d713770a3a4a623b9447bb6a6b93569cac 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -25,10 +25,12 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): @@ -37,6 +39,15 @@ def _train_input_fn(): return features, label +def _multiclass_train_input_fn(): + features = { + "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) + } + label = constant_op.constant( + [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + return features, label + + def _ranking_train_input_fn(): features = { "a.f1": constant_op.constant([[3.], [0.3], [1.]]), @@ -68,6 +79,10 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self._export_dir_base = tempfile.mkdtemp() + "export/" gfile.MkDir(self._export_dir_base) + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + def testFitAndEvaluateDontThrowException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 @@ -202,8 +217,128 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model.evaluate(input_fn=_ranking_train_input_fn, steps=1) model.predict(input_fn=_infer_ranking_train_input_fn) + def testDoesNotOverrideGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False) + + classifier.fit(input_fn=_train_input_fn, steps=15) + # When no override of global steps, 5 steps were used. + self._assert_checkpoint(classifier.model_dir, global_step=5) + + def testOverridesGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False, + override_global_step_value=10000000) + + classifier.fit(input_fn=_train_input_fn, steps=15) + self._assert_checkpoint(classifier.model_dir, global_step=10000000) + + def testFitAndEvaluateMultiClassTreePerClassDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) -class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase): + def testFitAndEvaluateMultiClassFullDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + +class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): def testTrainEvaluateInferDoesNotThrowError(self): head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( @@ -229,6 +364,115 @@ class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase): est.evaluate(input_fn=_eval_input_fn, steps=1) est.predict(input_fn=_eval_input_fn) + def testRankingDontThrowExceptionForForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + est = estimator.CoreGradientBoostedDecisionTreeRanker( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[ + core_feature_column.numeric_column("f1"), + core_feature_column.numeric_column("f2") + ], + ranking_model_pair_keys=("a", "b")) + + # Train for a few steps. + est.train(input_fn=_ranking_train_input_fn, steps=1000) + est.evaluate(input_fn=_ranking_train_input_fn, steps=1) + est.predict(input_fn=_infer_ranking_train_input_fn) + + def testFitAndEvaluateMultiClassTreePerClasssDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassFullDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 2fbe72951a559808ee2ee12a2efa07b1d857883a..04b46c3483fa25286078b88c2776b76e4f3c0bcf 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -58,7 +58,13 @@ def model_builder(features, * weight_column_name: The name of weight column. * center_bias: Whether a separate tree should be created for first fitting the bias. + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). Returns: A `ModelFnOps` object. @@ -74,6 +80,7 @@ def model_builder(features, use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -126,14 +133,16 @@ def model_builder(features, create_estimator_spec_op = getattr(head, "create_estimator_spec", None) + training_hooks = [] if num_trees: if center_bias: num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - training_hooks = [ + training_hooks.append( trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees) - ] + finalized_trees, + override_global_step_value)) if output_type == ModelBuilderOutputType.MODEL_FN_OPS: if use_core_libs and callable(create_estimator_spec_op): @@ -175,7 +184,12 @@ def model_builder(features, return model_fn_ops -def ranking_model_builder(features, labels, mode, params, config): +def ranking_model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Multi-machine batch gradient descent tree model for ranking. Args: @@ -198,7 +212,14 @@ def ranking_model_builder(features, labels, mode, params, config): for left and right part of the training pairs for ranking. For example, for an Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + Returns: A `ModelFnOps` object. @@ -215,6 +236,7 @@ def ranking_model_builder(features, labels, mode, params, config): logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] ranking_model_pair_keys = params["ranking_model_pair_keys"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -326,31 +348,55 @@ def ranking_model_builder(features, labels, mode, params, config): return update_op create_estimator_spec_op = getattr(head, "create_estimator_spec", None) - if use_core_libs and callable(create_estimator_spec_op): - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: - model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ - gbdt_batch.LEAF_INDEX] + training_hooks = [] if num_trees: if center_bias: num_trees += 1 + finalized_trees, attempted_trees = ( gbdt_model_main.get_number_of_trees_tensor()) - model_fn_ops.training_hooks.append( + training_hooks.append( trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees)) + finalized_trees, + override_global_step_value)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + return model_fn_ops diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py index 2e4151cac40f770e2bece70d752122eb7f34dd40..f137ada35524bf2467314f4a284ea35a82f06825 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py @@ -25,6 +25,7 @@ from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArg from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.summary_io import SummaryWriterCache @@ -150,12 +151,23 @@ class FeedFnHook(session_run_hook.SessionRunHook): class StopAfterNTrees(session_run_hook.SessionRunHook): """Stop training after building N full trees.""" - def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor): + def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor, + override_global_step_value=None): self._num_trees = n # num_attempted_trees_tensor and num_finalized_trees_tensor are both # tensors. self._num_attempted_trees_tensor = num_attempted_trees_tensor self._num_finalized_trees_tensor = num_finalized_trees_tensor + self._override_global_step_value = override_global_step_value + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created.") + + if self._override_global_step_value is not None: + self._override_global_step_op = state_ops.assign( + self._global_step_tensor, self._override_global_step_value) def before_run(self, run_context): del run_context # unused by StopTrainingAfterNTrees. @@ -175,6 +187,9 @@ class StopAfterNTrees(session_run_hook.SessionRunHook): num_attempted_trees > 2 * self._num_trees): logging.info("Requesting stop since we have reached %d trees.", num_finalized_trees) + if self._override_global_step_value is not None: + logging.info("Overriding global steps value.") + run_context.session.run(self._override_global_step_op) run_context.request_stop() diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 0b28f81e7ca9a1228adc5bde19c429265e0aa9b8..1375fddf2bea1a8f856c35d756c38a8beb14a53f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -125,6 +125,8 @@ void QuantizeFeatures( auto flat_values = values_tensor.flat(); for (int64 instance = 0; instance < num_values; ++instance) { const float value = flat_values(instance); + CHECK(!buckets_vector.empty()) + << "Got empty buckets for feature " << feature_index; auto bucket_iter = std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value); if (bucket_iter == buckets_vector.end()) { @@ -241,6 +243,11 @@ class CreateQuantileAccumulatorOp : public OpKernel { // other exceptions. If one already exists, it unrefs the new one. const Tensor* stamp_token_t; OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); + // An epsilon value of zero could cause perfoamance issues and is therefore, + // disallowed. + OP_REQUIRES( + context, epsilon_ > 0, + errors::InvalidArgument("An epsilon value of zero is not allowed.")); auto result = new QuantileStreamResource(epsilon_, num_quantiles_, max_elements_, generate_quantiles_, stamp_token_t->scalar()()); diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 1bfeed306641111718984b2097512e5ec3fa8630..6d9a6ee5a0d05465459393c4339558f1ca38d417 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -372,12 +372,18 @@ class GrowTreeEnsembleOp : public OpKernel { return; } + // Get the max tree depth. + const Tensor* max_tree_depth_t; + OP_REQUIRES_OK(context, + context->input("max_tree_depth", &max_tree_depth_t)); + const int32 max_tree_depth = max_tree_depth_t->scalar()(); + // Update and retrieve the growable tree. // If the tree is fully built and dropout was applied, it also adjusts the // weights of dropped and the last tree. boosted_trees::trees::DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, - dropout_seed); + dropout_seed, max_tree_depth); // Split tree nodes. for (auto& split_entry : best_splits) { @@ -494,7 +500,8 @@ class GrowTreeEnsembleOp : public OpKernel { boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree( boosted_trees::models::DecisionTreeEnsembleResource* const ensemble_resource, - const float learning_rate, const uint64 dropout_seed) { + const float learning_rate, const uint64 dropout_seed, + const int32 max_tree_depth) { const auto num_trees = ensemble_resource->num_trees(); if (num_trees <= 0 || ensemble_resource->LastTreeMetadata()->is_finalized()) { @@ -506,8 +513,7 @@ class GrowTreeEnsembleOp : public OpKernel { tree_config->add_nodes()->mutable_leaf(); boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = ensemble_resource->LastTreeMetadata(); - tree_metadata->set_is_finalized( - learner_config_.constraints().max_tree_depth() <= 1); + tree_metadata->set_is_finalized(max_tree_depth <= 1); tree_metadata->set_num_tree_weight_updates(1); } else { // The growable tree is by definition the last tree in the ensemble. @@ -518,8 +524,7 @@ class GrowTreeEnsembleOp : public OpKernel { << num_trees - 1 << " of ensemble of " << num_trees << " trees."; // Update growable tree metadata. tree_metadata->set_num_layers_grown(new_num_layers); - tree_metadata->set_is_finalized( - new_num_layers >= learner_config_.constraints().max_tree_depth()); + tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth); } UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed); return ensemble_resource->LastTree(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 1b7f59ea4218355a13f1df7264352bd68503bd19..5d4819b0f1cb598cfbe146f569aecd7883186339 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -131,6 +131,10 @@ class BaseSplitHandler(object): }, stamp_token, None) return control_flow_ops.group(update_1, *update_2[self]) + @abc.abstractmethod + def reset(self, stamp_token, next_stamp_token): + """Resets the state maintained by the handler.""" + @abc.abstractmethod def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index bf686237ff696dadad9713d26bf784d7442b80d0..efe29216c2a7d8aa985da54cdbb839b9e6f69078 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -202,3 +202,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # always return ready. are_splits_ready = constant_op.constant(True) return (are_splits_ready, partition_ids, gains, split_infos) + + def reset(self, stamp_token, next_stamp_token): + reset = self._stats_accumulator.flush(stamp_token, next_stamp_token) + return reset diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index df0bec1fe363e07bbff6b059e86076239bd605e9..2559fe9913f377ce38aa11dfa908cd25ec76dab4 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -79,6 +79,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -147,6 +148,11 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): num_quantiles=num_quantiles, name="QuantileAccumulator/{}".format(self._name)) + def reset(self, stamp_token, next_stamp_token): + reset_1 = self._stats_accumulator.flush(stamp_token, next_stamp_token) + reset_2 = self._quantile_accumulator.flush(stamp_token, next_stamp_token) + return control_flow_ops.group([reset_1, reset_2]) + class DenseSplitHandler(InequalitySplitHandler): """Computes stats and finds the best inequality splits on dense columns.""" @@ -264,6 +270,7 @@ class DenseSplitHandler(InequalitySplitHandler): self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, self._min_node_weight, self._loss_uses_sum_reduction)) + return are_splits_ready, partition_ids, gains, split_infos @@ -579,8 +586,10 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, example_partition_ids, feature_ids, gradients, hessians = ( control_flow_ops.cond( - math_ops.logical_and(are_buckets_ready, is_active[0]), - ready_inputs_fn, not_ready_inputs_fn)) + math_ops.logical_and( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + is_active[0]), ready_inputs_fn, not_ready_inputs_fn)) return (quantile_values, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) @@ -674,8 +683,10 @@ def sparse_make_stats_update( lambda: handler_not_active)) example_partition_ids, feature_ids, gradients, hessians = ( - control_flow_ops.cond(are_buckets_ready, quantiles_ready, - quantiles_not_ready)) + control_flow_ops.cond( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + quantiles_ready, quantiles_not_ready)) return (quantile_indices, quantile_values, quantile_shape, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index d59732cf92eb85e88732ac5a17dccf475ae5342f..5d82c4cae5dbe28c82fa8754a7c65db62a2e6814 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -1072,8 +1072,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # 2x2 matrix for each instance hessian_0 = [[0.12, 0.02], [0.3, 0.11]] hessian_1 = [[0.07, -0.2], [-0.5, 0.2]] @@ -1167,8 +1167,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # Each hessian is a diagonal from a full hessian matrix. hessian_0 = [0.12, 0.11] hessian_1 = [0.07, 0.2] @@ -1406,6 +1406,100 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testEmptyBuckets(self): + """Test that reproduces the case when quantile buckets were empty.""" + with self.test_session() as sess: + sparse_column = array_ops.sparse_placeholder(dtypes.float32) + + # We have two batches - at first, a sparse feature is empty. + empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) + empty_values = array_ops.constant([], dtype=dtypes.float32) + empty_sparse_column = sparse_tensor.SparseTensor(empty_indices, + empty_values, [4, 2]) + empty_sparse_column = empty_sparse_column.eval(session=sess) + + # For the second batch, the sparse feature is not empty. + non_empty_indices = array_ops.constant( + [[0, 0], [2, 1], [3, 2]], dtype=dtypes.int64, shape=[3, 2]) + non_empty_values = array_ops.constant( + [0.52, 0.3, 0.52], dtype=dtypes.float32) + non_empty_sparse_column = sparse_tensor.SparseTensor( + non_empty_indices, non_empty_values, [4, 2]) + non_empty_sparse_column = non_empty_sparse_column.eval(session=sess) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + # First, calculate quantiles and try to update on an empty data for a + # feature. + are_splits_ready = ( + sess.run( + are_splits_ready, + feed_dict={sparse_column: empty_sparse_column})) + self.assertFalse(are_splits_ready) + + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + + # Now the feature in the second batch is not empty, but buckets + # calculated on the first batch are empty. + are_splits_ready2, partitions, gains, splits = ( + sess.run( + [are_splits_ready2, partitions, gains, splits], + feed_dict={sparse_column: non_empty_sparse_column})) + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + # Since the buckets were empty, we can't calculate the splits. + self.assertEqual(len(partitions), 0) + self.assertEqual(len(gains), 0) + self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): with self.test_session() as sess: # One data example only, one leaf and thus one quantile bucket.The same diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index c120dd8a6c156ec9eb7ba0b6c552f5138bd21a16..f19e5116f5865777ab65e1add2777ac41105acc0 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -58,6 +58,8 @@ namespace quantiles { // Compute: O(n * log(1/eps * log(eps * n))). // Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the // entire dataset. +// An epsilon value of zero would make the algorithm extremely inefficent and +// therefore, is disallowed. template > class WeightedQuantilesStream { @@ -69,6 +71,9 @@ class WeightedQuantilesStream { explicit WeightedQuantilesStream(double eps, int64 max_elements) : eps_(eps), buffer_(1LL, 2LL), finalized_(false) { + // See the class documentation. An epsilon value of zero could cause + // perfoamance issues. + QCHECK(eps > 0) << "An epsilon value of zero is not allowed."; std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements); buffer_ = Buffer(block_size_, max_elements); summary_levels_.reserve(max_levels_); diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc index f63c199ad6146c23c22437ffe2287a77ee91ca44..22ac9edb72ea91ecef6fd1dff9f399b3c9020083 100644 --- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc @@ -56,6 +56,7 @@ REGISTER_OP("GrowTreeEnsemble") .Input("next_stamp_token: int64") .Input("learning_rate: float") .Input("dropout_seed: int64") + .Input("max_tree_depth: int32") .Input("partition_ids: num_handlers * int32") .Input("gains: num_handlers * float") .Input("splits: num_handlers * string") @@ -67,6 +68,8 @@ REGISTER_OP("GrowTreeEnsemble") TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input)); // Dropout seed. TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input)); + // Maximum tree depth. + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_input)); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 3e524efbeac74ff754d63cae92b3e194411cb2de..e39e1de8d1954c7f4dcab87d7727a64affa13c8c 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here, tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -321,9 +321,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -443,7 +444,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here - tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -472,9 +473,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -632,8 +634,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -657,9 +658,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -773,8 +775,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -794,9 +795,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty. @@ -839,8 +841,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -865,9 +866,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -946,8 +948,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -967,9 +968,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1048,9 +1050,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty as post-pruning will prune @@ -1094,8 +1097,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Second handler has positive gain. @@ -1115,9 +1117,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1194,9 +1197,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the negative gain split of partition 1 to be pruned and the @@ -1335,7 +1339,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER, # Dropout will have no effect, since the tree will not be fully grown. - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -1364,9 +1368,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -1543,7 +1548,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1567,9 +1572,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -1669,7 +1675,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) learner_config.constraints.max_number_of_unique_feature_columns = 3 - learner_config = learner_config.SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([7.62], dtype=np.float32) @@ -1692,9 +1697,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) _, serialized = session.run( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 643d8d249867f6e97ff5affdf8c4c858907a871e..d0d1249bd6afc9cdbf6d88298c5024a4a54a5073 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -51,6 +51,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import device_setter + # Key names for prediction dict. ENSEMBLE_STAMP = "ensemble_stamp" PREDICTIONS = "predictions" @@ -353,6 +354,9 @@ class GradientBoostedDecisionTreeModel(object): self._gradient_shape = tensor_shape.scalar() self._hessian_shape = tensor_shape.scalar() else: + if center_bias: + raise ValueError("Center bias should be False for multiclass.") + self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) if (learner_config.multi_class_strategy == learner_pb2.LearnerConfig.FULL_HESSIAN): @@ -380,6 +384,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._max_tree_depth = variables.Variable( + initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, @@ -520,9 +526,6 @@ class GradientBoostedDecisionTreeModel(object): if not input_deps: raise ValueError("No input tensors for prediction.") - if any(i.device != input_deps[0].device for i in input_deps): - raise ValueError("All input tensors should be on the same device.") - # Get most current model stamp. ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle) @@ -896,7 +899,7 @@ class GradientBoostedDecisionTreeModel(object): reset_ops = [] for handler in handlers: - reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0)) + reset_ops.append(handler.reset(stamp_token, next_stamp_token)) if self._center_bias: reset_ops.append( bias_stats_accumulator.flush(stamp_token, next_stamp_token)) @@ -1054,7 +1057,8 @@ class GradientBoostedDecisionTreeModel(object): splits=split_info_list, learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_not_ready_fn(): # Don't grow the ensemble, just update the stamp. @@ -1068,7 +1072,8 @@ class GradientBoostedDecisionTreeModel(object): splits=[], learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_fn(): # Conditionally grow an ensemble depending on whether the splits @@ -1108,6 +1113,9 @@ class GradientBoostedDecisionTreeModel(object): def get_number_of_trees_tensor(self): return self._finalized_trees, self._attempted_trees + def get_max_tree_depth(self): + return self._max_tree_depth + def train(self, loss, predictions_dict, labels): """Updates the accumalator stats and grows the ensemble. diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 1bfd27305d569668a0bd67d876e59eec082296b3..58fadffce32f9a8fec047d1e99f9f4eb5a710d91 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -85,7 +85,7 @@ Status BigQueryTableAccessor::New( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor) { if (timestamp_millis <= 0) { return errors::InvalidArgument( @@ -94,29 +94,19 @@ Status BigQueryTableAccessor::New( const string& big_query_end_point = end_point.empty() ? kBigQueryEndPoint : end_point; if (auth_provider == nullptr && http_request_factory == nullptr) { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition)); - } else { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition, std::move(auth_provider), - std::move(http_request_factory))); + http_request_factory = std::make_shared(); + auto compute_engine_metadata_client = + std::make_shared(http_request_factory); + auth_provider = std::unique_ptr( + new GoogleAuthProvider(compute_engine_metadata_client)); } - return (*accessor)->ReadSchema(); -} -BigQueryTableAccessor::BigQueryTableAccessor( - const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition) - : BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - end_point, columns, partition, - std::unique_ptr(new GoogleAuthProvider()), - std::unique_ptr( - new CurlHttpRequest::Factory())) { - row_buffer_.resize(row_buffer_size); + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition, std::move(auth_provider), + std::move(http_request_factory))); + + return (*accessor)->ReadSchema(); } BigQueryTableAccessor::BigQueryTableAccessor( @@ -124,7 +114,7 @@ BigQueryTableAccessor::BigQueryTableAccessor( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory) + std::shared_ptr http_request_factory) : project_id_(project_id), dataset_id_(dataset_id), table_id_(table_id), diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h index b349063715c903c982cfe2fb116b6525e35ff63b..1af43a3e1070d466bb50019f12b22a060c1e6ab1 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -109,24 +109,17 @@ class BigQueryTableAccessor { const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor); /// \brief Constructs an object for a given table and partition. - BigQueryTableAccessor(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const string& end_point, - const std::vector& columns, - const BigQueryTablePartition& partition); - - /// Used for unit testing. BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory); + std::shared_ptr http_request_factory); /// \brief Parses column values for a given row. Status ParseColumnValues(const Json::Value& value, @@ -199,7 +192,7 @@ class BigQueryTableAccessor { SchemaNode schema_root_; std::unique_ptr auth_provider_; - std::unique_ptr http_request_factory_; + std::shared_ptr http_request_factory_; TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor); }; diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 8f521ffee4d31e090c13bac98290656d6e1d330e..1ab150d74ac00c5f9acf3c9399880708b2f62b1e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -148,6 +148,9 @@ class TPUClusterResolver(ClusterResolver): else: tpu = self._envVarFallback() + if tpu is None: + raise ValueError('Please provide a TPU Name to connect to.') + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name self._credentials = credentials @@ -259,11 +262,11 @@ class TPUClusterResolver(ClusterResolver): if 'state' in response and response['state'] != 'READY': raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (self._tpu, response['state'])) + (compat.as_text(self._tpu), response['state'])) if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, - response['health'])) + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % + (compat.as_text(self._tpu), response['health'])) if 'networkEndpoints' in response: worker_list = [ diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 6c93487e0d6d5caf576f1af80c36e4df895e6afa..f6c928e2be62e7292c6feaa3bb26fd463320158b 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -471,7 +471,6 @@ if (tensorflow_ENABLE_GPU) ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake index 45a0096085cc2a6332c82e1ea284812acdd45152..33bb31148d2e5b7ca177d7c30b7781e8f620c3cb 100644 --- a/tensorflow/contrib/cmake/external/eigen.cmake +++ b/tensorflow/contrib/cmake/external/eigen.cmake @@ -19,6 +19,12 @@ # build_file = "eigen.BUILD", #) +option(eigen_PATCH_FILE "Patch file to apply to eigen" OFF) +set(eigen_PATCH_COMMAND "") +if(eigen_PATCH_FILE) + set(eigen_PATCH_COMMAND PATCH_COMMAND patch -p0 -i "${eigen_PATCH_FILE}") +endif(eigen_PATCH_FILE) + include (ExternalProject) # We parse the current Eigen version and archive hash from the bazel configuration @@ -45,6 +51,7 @@ ExternalProject_Add(eigen URL ${eigen_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" INSTALL_DIR "${eigen_INSTALL}" + ${eigen_PATCH_COMMAND} CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF diff --git a/tensorflow/contrib/cmake/external/highwayhash.cmake b/tensorflow/contrib/cmake/external/highwayhash.cmake index a6e8a38d8c2ee3deb5453c264e0c5eb23248301f..7d260b85f21e7e56e153daf550c81155e4b68777 100644 --- a/tensorflow/contrib/cmake/external/highwayhash.cmake +++ b/tensorflow/contrib/cmake/external/highwayhash.cmake @@ -20,14 +20,6 @@ set(highwayhash_TAG be5edafc2e1a455768e260ccd68ae7317b6690ee) set(highwayhash_BUILD ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/src/highwayhash) set(highwayhash_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/install) -# put highwayhash includes in the directory where they are expected -add_custom_target(highwayhash_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash - DEPENDS highwayhash) - -add_custom_target(highwayhash_copy_headers_to_destination - DEPENDS highwayhash_create_destination_dir) - if(WIN32) set(highwayhash_HEADERS "${highwayhash_BUILD}/highwayhash/*.h") set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/highwayhash.lib) @@ -36,6 +28,20 @@ else() set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/libhighwayhash.a) endif() +set(highwayhash_HEADERS + "${highwayhash_INSTALL}/include/code_annotation.h" + "${highwayhash_INSTALL}/include/highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sip_hash.h" + "${highwayhash_INSTALL}/include/sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sse41_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/state_helpers.h" + "${highwayhash_INSTALL}/include/types.h" + "${highwayhash_INSTALL}/include/vec.h" + "${highwayhash_INSTALL}/include/vec2.h" +) + ExternalProject_Add(highwayhash PREFIX highwayhash GIT_REPOSITORY ${highwayhash_URL} @@ -50,5 +56,15 @@ ExternalProject_Add(highwayhash -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${highwayhash_INSTALL}) -add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${highwayhash_INSTALL}/include/ ${highwayhash_INCLUDE_DIR}/highwayhash) +# put highwayhash includes in the directory where they are expected +add_custom_target(highwayhash_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash + DEPENDS highwayhash) + +add_custom_target(highwayhash_copy_headers_to_destination + DEPENDS highwayhash_create_destination_dir) + +foreach(header_file ${highwayhash_HEADERS}) + add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${highwayhash_INCLUDE_DIR}/highwayhash/) +endforeach() diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index eba3bcfc79efe87d0a45c979c5accfa1b6511ed0..1d638e64023c7e2706d8d97ff8679677b6cd289d 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -20,14 +20,6 @@ set(nsync_TAG 1.20.0) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) -# put nsync includes in the directory where they are expected -add_custom_target(nsync_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} - DEPENDS nsync) - -add_custom_target(nsync_copy_headers_to_destination - DEPENDS nsync_create_destination_dir) - if(WIN32) set(nsync_HEADERS "${nsync_BUILD}/public/*.h") set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync.lib) @@ -49,7 +41,35 @@ ExternalProject_Add(nsync -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${nsync_INSTALL} - -DNSYNC_LANGUAGE:STRING=c++11) + -DNSYNC_LANGUAGE:STRING=c++11) + +set(nsync_HEADERS + "${nsync_INSTALL}/include/nsync.h" + "${nsync_INSTALL}/include/nsync_atomic.h" + "${nsync_INSTALL}/include/nsync_counter.h" + "${nsync_INSTALL}/include/nsync_cpp.h" + "${nsync_INSTALL}/include/nsync_cv.h" + "${nsync_INSTALL}/include/nsync_debug.h" + "${nsync_INSTALL}/include/nsync_mu.h" + "${nsync_INSTALL}/include/nsync_mu_wait.h" + "${nsync_INSTALL}/include/nsync_note.h" + "${nsync_INSTALL}/include/nsync_once.h" + "${nsync_INSTALL}/include/nsync_time.h" + "${nsync_INSTALL}/include/nsync_time_internal.h" + "${nsync_INSTALL}/include/nsync_waiter.h" +) + +# put nsync includes in the directory where they are expected +add_custom_target(nsync_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} + DEPENDS nsync) + +add_custom_target(nsync_copy_headers_to_destination + DEPENDS nsync_create_destination_dir) + +foreach(header_file ${nsync_HEADERS}) + add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${nsync_INCLUDE_DIR}/) +endforeach() + -add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${nsync_INSTALL}/include/ ${nsync_INCLUDE_DIR}/) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 75e00f32675df1b7e523bc7e8bb44fa584b79347..9045290679b87a201df8b930df6ff9a4ec106dcf 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -115,7 +115,6 @@ tensorflow/contrib/coder tensorflow/contrib/coder/kernels tensorflow/contrib/coder/ops tensorflow/contrib/coder/python -tensorflow/contrib/coder/python/layers tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 32b185f07b6ba836ffb47e85beff6fb2481fdc3e..6d86daf5f174a3238ab92e5bba6085c904766766 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -198,7 +198,7 @@ function(add_python_module MODULE_NAME) # so we currently add explicit commands to include those files # later on in this script. if (NOT "${script}" MATCHES "_test\.py$") - add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD + add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/${script} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/${script}) endif() endforeach() @@ -297,7 +297,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ) target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE tf_protos_cc - tf_python_protos_cc + tf_python_protos_cc ${tensorflow_EXTERNAL_LIBRARIES} ) @@ -549,15 +549,15 @@ if(WIN32) ${NUMPY_INCLUDE_DIR} ) #target_link_libraries(pywrap_tensorflow_internal_static - # tf_protos_cc - # tf_python_protos_cc + # tf_protos_cc + # tf_python_protos_cc #) add_dependencies(pywrap_tensorflow_internal_static tf_protos_cc tf_python_protos_cc) set(pywrap_tensorflow_internal_static_dependencies $ $ $ - ${nsync_STATIC_LIBRARIES} + ${nsync_STATIC_LIBRARIES} ) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -737,7 +737,7 @@ endif() ######################################################## # Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_init_files.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -763,57 +763,40 @@ file(WRITE "${api_init_list_file}" "${api_init_files}") # recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue. # To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem, # and should be removed if the path issue can be resolved. +# UPDATE: Below block appears to handle multiple items in PATH correctly, but risks command line limits if PATH is large. +# If you have issues, try `set(PY_RUNTIME_ENV "PATH=${mkl_BIN_DIRS}")` instead. ### -if (tensorflow_ENABLE_MKL_SUPPORT) +set(PY_RUNTIME_ENV "") +if(tensorflow_ENABLE_MKL_SUPPORT) # add mkl dist dlls to system path for python - # TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths, - # so we have to specify only one path in it to work around the issue. We need this if/else - # to protect overwriting CUDA environments - set(PY_RUNTIME_ENV ${mkl_BIN_DIRS}) - add_custom_command( - OUTPUT ${api_init_files} - DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - - # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "--package=tensorflow.python" - "--apiname=tensorflow" - "${api_init_list_file}" - - COMMENT "Generating __init__.py files for Python API." - WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" - VERBATIM - ) -else (tensorflow_ENABLE_MKL_SUPPORT) - add_custom_command( - OUTPUT ${api_init_files} - DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - - # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "--package=tensorflow.python" - "--apiname=tensorflow" - "${api_init_list_file}" - - COMMENT "Generating __init__.py files for Python API." - WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" - ) -endif (tensorflow_ENABLE_MKL_SUPPORT) + file(TO_CMAKE_PATH "$ENV{PATH}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV ${mkl_BIN_DIRS} ${PY_RUNTIME_ENV}) + file(TO_NATIVE_PATH "${PY_RUNTIME_ENV}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV "PATH=${PY_RUNTIME_ENV}") +endif(tensorflow_ENABLE_MKL_SUPPORT) + +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + VERBATIM +) add_custom_target(tf_python_api SOURCES ${api_init_files}) add_dependencies(tf_python_api tf_python_ops) @@ -848,12 +831,12 @@ add_custom_command( DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" "--package=tensorflow.python.estimator" "--apiname=estimator" - "--output_package=tensorflow.python.estimator.api" + "--output_package=tensorflow.python.estimator.api" "${estimator_api_init_list_file}" COMMENT "Generating __init__.py files for Python API." diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index b2330c4e340d531f70234de812ab6f6b2e5c1160..2c878c17167c662d10a8c7dabf41687efdbf65d8 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -122,6 +122,17 @@ function(AddPythonTests) endforeach() endfunction(AddPythonTests) +# +# ensure that every element is an existing file +# +function(CheckExists TYPE SOURCES) + foreach(source ${SOURCES}) + if(NOT EXISTS ${source}) + message(SEND_ERROR "${TYPE} not found: ${source}") + endif() + endforeach(source) +endfunction(CheckExists) + if (tensorflow_BUILD_PYTHON_TESTS) # # python tests. This assumes that the tensorflow wheel is @@ -145,7 +156,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" - "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" @@ -198,7 +208,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). @@ -256,10 +265,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Flaky because of local cluster creation. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" - "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" + "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" @@ -329,6 +337,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 ) endif() + CheckExists(${tf_test_src_py_exclude}) list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) AddPythonTests( @@ -480,6 +489,7 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/cc/saved_model/*_test.cc" ) + CheckExists(${tf_test_src_simple_exclude}) list(REMOVE_ITEM tf_test_src_simple ${tf_test_src_simple_exclude} ${tf_cc_saved_model_test_srcs} @@ -494,6 +504,7 @@ if (tensorflow_BUILD_CC_TESTS) ${tf_core_profiler_test_srcs} ) + CheckExists(${tf_src_testlib}) set(tf_test_lib tf_test_lib) add_library(${tf_test_lib} STATIC ${tf_src_testlib}) diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index a2c6e413039ee3b5af3cb53d1af3325037536d36..855c824ead2f7de4c37db2d2a3648a9ee00fb9e9 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -1,5 +1,5 @@ # Description: -# Contains tools related to data compression. +# Contains ops related to data compression. package(default_visibility = [ "//learning/brain:__subpackages__", @@ -168,7 +168,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":coder_ops_py", - ":entropybottleneck_py", ], ) @@ -205,44 +204,3 @@ tf_py_test( ], main = "python/ops/coder_ops_test.py", ) - -py_library( - name = "entropybottleneck_py", - srcs = [ - "python/layers/entropybottleneck.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":coder_ops_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:functional_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras:engine", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "entropybottleneck_py_test", - srcs = [ - "python/layers/entropybottleneck_test.py", - ], - additional_deps = [ - ":entropybottleneck_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:variables", - "//tensorflow/python:training", - ], - main = "python/layers/entropybottleneck_test.py", -) diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md deleted file mode 100644 index c6c379c458893551b765327c0c1cbfff7f24f9c3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# Entropy coder - -This module contains range encoder and range decoder which can encode integer -data into string with cumulative distribution functions (CDF). - -## Data and CDF values - -The data to be encoded should be non-negative integers in half-open interval -`[0, m)`. Then a CDF is represented as an integral vector of length `m + 1` -where `CDF(i) = f(Pr(X < i) * 2^precision)` for i = 0,1,...,m, and `precision` -is an attribute in range `0 < precision <= 16`. The function `f` maps real -values into integers, e.g., round or floor. It is important that to encode a -number `i`, `CDF(i + 1) - CDF(i)` cannot be zero. - -Note that we used `Pr(X < i)` not `Pr(X <= i)`, and therefore CDF(0) = 0 always. - -## RangeEncode: data shapes and CDF shapes - -For each data element, its CDF has to be provided. Therefore if the shape of CDF -should be `data.shape + (m + 1,)` in NumPy-like notation. For example, if `data` -is a 2-D tensor of shape (10, 10) and its elements are in `[0, 64)`, then the -CDF tensor should have shape (10, 10, 65). - -This may make CDF tensor too large, and in many applications all data elements -may have the same probability distribution. To handle this, `RangeEncode` -supports limited broadcasting CDF into data. Broadcasting is limited in the -following sense: - -- All CDF axes but the last one is broadcasted into data but not the other way - around, -- The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`. - -In the previous example where data has shape (10, 10), the following are -acceptable CDF shapes: - -- (10, 10, 65) -- (1, 10, 65) -- (10, 1, 65) -- (1, 1, 65) - -## RangeDecode - -`RangeEncode` encodes neither data shape nor termination character. Therefore -the decoder should know how many characters are encoded into the string, and -`RangeDecode` takes the encoded data shape as the second argument. The same -shape restrictions as `RangeEncode` inputs apply here. - -## Example - -```python -data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32) - -histogram = tf.bincount(data, minlength=10, maxlength=10) -cdf = tf.cumsum(histogram, exclusive=False) -# CDF should have length m + 1. -cdf = tf.pad(cdf, [[1, 0]]) -# CDF axis count must be one more than data. -cdf = tf.reshape(cdf, [1, 1, -1]) - -# Note that data has 2^14 elements, and therefore the sum of CDF is 2^14. -data = tf.cast(data, tf.int16) -encoded = coder.range_encode(data, cdf, precision=14) -decoded = coder.range_decode(encoded, tf.shape(data), cdf, precision=14) - -# data and decoded should be the same. -sess = tf.Session() -x, y = sess.run((data, decoded)) -assert np.all(x == y) -``` - -## Authors -Sung Jin Hwang (github: [ssjhv](https://github.com/ssjhv)) and Nick Johnston -(github: [nmjohn](https://github.com/nmjohn)) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py index 99b8ac7595ec632b2918e6b7ca22c06dd7f0a8b3..8897312046c63c42d85e7fba5b62d2ed908dd6e9 100644 --- a/tensorflow/contrib/coder/__init__.py +++ b/tensorflow/contrib/coder/__init__.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Data compression tools.""" +"""Data compression ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import -from tensorflow.contrib.coder.python.layers.entropybottleneck import * from tensorflow.contrib.coder.python.ops.coder_ops import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py deleted file mode 100644 index 0c997bd4fdfa4233117c9fec2c4397301b1c8cb9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ /dev/null @@ -1,697 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Entropy bottleneck layer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.ops import coder_ops - -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine import base_layer -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.summary import summary - - -class EntropyBottleneck(base_layer.Layer): - """Entropy bottleneck layer. - - This layer can be used to model the entropy (the amount of information - conveyed) of the tensor passing through it. During training, this can be used - to impose a (soft) entropy constraint on its activations, limiting the amount - of information flowing through the layer. Note that this is distinct from - other types of bottlenecks, which reduce the dimensionality of the space, for - example. Dimensionality reduction does not limit the amount of information, - and does not enable efficient data compression per se. - - After training, this layer can be used to compress any input tensor to a - string, which may be written to a file, and to decompress a file which it - previously generated back to a reconstructed tensor (possibly on a different - machine having access to the same model checkpoint). The entropies estimated - during training or evaluation are approximately equal to the average length of - the strings in bits. - - The layer implements a flexible probability density model to estimate entropy, - which is described in the appendix of the paper (please cite the paper if you - use this code for scientific work): - - "Variational image compression with a scale hyperprior" - - Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston - - https://arxiv.org/abs/1802.01436 - - The layer assumes that the input tensor is at least 2D, with a batch dimension - at the beginning and a channel dimension as specified by `data_format`. The - layer trains an independent probability density model for each channel, but - assumes that across all other dimensions, the inputs are i.i.d. (independent - and identically distributed). Because the entropy (and hence, average - codelength) is a function of the densities, this assumption may have a direct - effect on the compression performance. - - Because data compression always involves discretization, the outputs of the - layer are generally only approximations of its inputs. During training, - discretization is modeled using additive uniform noise to ensure - differentiability. The entropies computed during training are differential - entropies. During evaluation, the data is actually quantized, and the - entropies are discrete (Shannon entropies). To make sure the approximated - tensor values are good enough for practical purposes, the training phase must - be used to balance the quality of the approximation with the entropy, by - adding an entropy term to the training loss, as in the following example. - - Here, we use the entropy bottleneck to compress the latent representation of - an autoencoder. The data vectors `x` in this case are 4D tensors in - `'channels_last'` format (for example, 16x16 pixel grayscale images). - - The layer always produces exactly one auxiliary loss and one update op which - are only significant for compression and decompression. To use the compression - feature, the auxiliary loss must be minimized during or after training. After - that, the update op must be executed at least once. Here, we simply attach - them to the main training step. - - Training: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - entropy_bottleneck = EntropyBottleneck() - y_, likelihoods = entropy_bottleneck(y, training=True) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element - # (note that taking the natural logarithm and dividing by `log(2)` is - # equivalent to taking base-2 logarithms): - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - - # Minimize loss and auxiliary loss, and execute update op. - main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) - main_step = optimizer.minimize(main_loss) - # 1e-2 is a good starting point for the learning rate of the auxiliary loss, - # assuming Adam is used. - aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2) - aux_step = optimizer.minimize(entropy_bottleneck.losses[0]) - step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0]) - ``` - - Evaluation: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - y_, likelihoods = EntropyBottleneck()(y, training=False) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element: - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - ``` - - To be able to compress the bottleneck tensor and decompress it in a different - session, or on a different machine, you need three items: - - The compressed representations stored as strings. - - The shape of the bottleneck for these string representations as a `Tensor`, - as well as the number of channels of the bottleneck at graph construction - time. - - The checkpoint of the trained model that was used for compression. Note: - It is crucial that the auxiliary loss produced by this layer is minimized - during or after training, and that the update op is run after training and - minimization of the auxiliary loss, but *before* the checkpoint is saved. - - Compression: - ``` - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - strings = EntropyBottleneck().compress(y) - shape = tf.shape(y)[1:] - ``` - - Decompression: - ``` - strings = tf.placeholder(tf.string, shape=[None]) - shape = tf.placeholder(tf.int32, shape=[3]) - entropy_bottleneck = EntropyBottleneck(dtype=tf.float32) - y_ = entropy_bottleneck.decompress(strings, shape, channels=5) - x_ = backward_transform(y_) - ``` - Here, we assumed that the tensor produced by the forward transform has 5 - channels. - - The above four use cases can also be implemented within the same session (i.e. - on the same `EntropyBottleneck` instance), for testing purposes, etc., by - calling the object more than once. - - Arguments: - init_scale: Float. A scaling factor determining the initial width of the - probability densities. This should be chosen big enough so that the - range of values of the layer inputs roughly falls within the interval - [`-init_scale`, `init_scale`] at the beginning of training. - filters: An iterable of ints, giving the number of filters at each layer of - the density model. Generally, the more filters and layers, the more - expressive is the density model in terms of modeling more complicated - distributions of the layer inputs. For details, refer to the paper - referenced above. The default is `[3, 3, 3]`, which should be sufficient - for most practical purposes. - tail_mass: Float, between 0 and 1. The bottleneck layer automatically - determines the range of input values that should be represented based on - their frequency of occurrence. Values occurring in the tails of the - distributions will be clipped to that range during compression. - `tail_mass` determines the amount of probability mass in the tails which - is cut off in the worst case. For example, the default value of `1e-9` - means that at most 1 in a billion input samples will be clipped to the - range. - optimize_integer_offset: Boolean. Typically, the input values of this layer - are floats, which means that quantization during evaluation can be - performed with an arbitrary offset. By default, the layer determines that - offset automatically. In special situations, such as when it is known that - the layer will receive only full integer values during evaluation, it can - be desirable to set this argument to `False` instead, in order to always - quantize to full integer values. - likelihood_bound: Float. If positive, the returned likelihood values are - ensured to be greater than or equal to this value. This prevents very - large gradients with a typical entropy loss (defaults to 1e-9). - range_coder_precision: Integer, between 1 and 16. The precision of the range - coder used for compression and decompression. This trades off computation - speed with compression efficiency, where 16 is the slowest but most - efficient setting. Choosing lower values may increase the average - codelength slightly compared to the estimated entropies. - data_format: Either `'channels_first'` or `'channels_last'` (default). - trainable: Boolean. Whether the layer should be trained. - name: String. The name of the layer. - dtype: Default dtype of the layer's parameters (default of `None` means use - the type of the first input). - - Read-only properties: - init_scale: See above. - filters: See above. - tail_mass: See above. - optimize_integer_offset: See above. - likelihood_bound: See above. - range_coder_precision: See above. - data_format: See above. - name: String. See above. - dtype: See above. - trainable_variables: List of trainable variables. - non_trainable_variables: List of non-trainable variables. - variables: List of all variables of this layer, trainable and non-trainable. - updates: List of update ops of this layer. Always contains exactly one - update op, which must be run once after the last training step, before - `compress` or `decompress` is used. - losses: List of losses added by this layer. Always contains exactly one - auxiliary loss, which must be added to the training loss. - - Mutable properties: - trainable: Boolean. Whether the layer should be trained. - input_spec: Optional `InputSpec` object specifying the constraints on inputs - that can be accepted by the layer. - """ - - def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9, - optimize_integer_offset=True, likelihood_bound=1e-9, - range_coder_precision=16, data_format="channels_last", **kwargs): - super(EntropyBottleneck, self).__init__(**kwargs) - self._init_scale = float(init_scale) - self._filters = tuple(int(f) for f in filters) - self._tail_mass = float(tail_mass) - if not 0 < self.tail_mass < 1: - raise ValueError( - "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass)) - self._optimize_integer_offset = bool(optimize_integer_offset) - self._likelihood_bound = float(likelihood_bound) - self._range_coder_precision = int(range_coder_precision) - self._data_format = data_format - self._channel_axis(2) # trigger ValueError early - self.input_spec = base_layer.InputSpec(min_ndim=2) - - @property - def init_scale(self): - return self._init_scale - - @property - def filters(self): - return self._filters - - @property - def tail_mass(self): - return self._tail_mass - - @property - def optimize_integer_offset(self): - return self._optimize_integer_offset - - @property - def likelihood_bound(self): - return self._likelihood_bound - - @property - def range_coder_precision(self): - return self._range_coder_precision - - @property - def data_format(self): - return self._data_format - - def _channel_axis(self, ndim): - try: - return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format] - except KeyError: - raise ValueError("Unsupported `data_format` for {} layer: {}.".format( - self.__class__.__name__, self.data_format)) - - def _logits_cumulative(self, inputs, stop_gradient): - """Evaluate logits of the cumulative densities. - - Args: - inputs: The values at which to evaluate the cumulative densities, expected - to be a `Tensor` of shape `(channels, 1, batch)`. - stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so - that the gradient of the output with respect to the density model - parameters is disconnected (the gradient with respect to `inputs` is - left untouched). - - Returns: - A `Tensor` of the same shape as `inputs`, containing the logits of the - cumulative densities evaluated at the given inputs. - """ - logits = inputs - - for i in range(len(self.filters) + 1): - matrix = self._matrices[i] - if stop_gradient: - matrix = array_ops.stop_gradient(matrix) - logits = math_ops.matmul(matrix, logits) - - bias = self._biases[i] - if stop_gradient: - bias = array_ops.stop_gradient(bias) - logits += bias - - if i < len(self._factors): - factor = self._factors[i] - if stop_gradient: - factor = array_ops.stop_gradient(factor) - logits += factor * math_ops.tanh(logits) - - return logits - - def build(self, input_shape): - """Builds the layer. - - Creates the variables for the network modeling the densities, creates the - auxiliary loss estimating the median and tail quantiles of the densities, - and then uses that to create the probability mass functions and the update - op that produces the discrete cumulative density functions used by the range - coder. - - Args: - input_shape: Shape of the input tensor, used to get the number of - channels. - - Raises: - ValueError: if `input_shape` doesn't specify the length of the channel - dimension. - """ - input_shape = tensor_shape.TensorShape(input_shape) - channel_axis = self._channel_axis(input_shape.ndims) - channels = input_shape[channel_axis].value - if channels is None: - raise ValueError("The channel dimension of the inputs must be defined.") - self.input_spec = base_layer.InputSpec( - ndim=input_shape.ndims, axes={channel_axis: channels}) - filters = (1,) + self.filters + (1,) - scale = self.init_scale ** (1 / (len(self.filters) + 1)) - - # Create variables. - self._matrices = [] - self._biases = [] - self._factors = [] - for i in range(len(self.filters) + 1): - init = np.log(np.expm1(1 / scale / filters[i + 1])) - matrix = self.add_variable( - "matrix_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], filters[i]), - initializer=init_ops.Constant(init)) - matrix = nn.softplus(matrix) - self._matrices.append(matrix) - - bias = self.add_variable( - "bias_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.RandomUniform(-.5, .5)) - self._biases.append(bias) - - if i < len(self.filters): - factor = self.add_variable( - "factor_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.Zeros()) - factor = math_ops.tanh(factor) - self._factors.append(factor) - - # To figure out what range of the densities to sample, we need to compute - # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we - # can't take inverses of the cumulative directly, we make it an optimization - # problem: - # `quantiles = argmin(|logit(cumulative) - target|)` - # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`. - # Taking the logit (inverse of sigmoid) of the cumulative makes the - # representation of the right target more numerically stable. - - # Numerically stable way of computing logits of `tail_mass / 2` - # and `1 - tail_mass / 2`. - target = np.log(2 / self.tail_mass - 1) - # Compute lower and upper tail quantile as well as median. - target = constant_op.constant([-target, 0, target], dtype=self.dtype) - - def quantiles_initializer(shape, dtype=None, partition_info=None): - del partition_info # unused - assert tuple(shape[1:]) == (1, 3) - init = constant_op.constant( - [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype) - return array_ops.tile(init, (shape[0], 1, 1)) - - quantiles = self.add_variable( - "quantiles", shape=(channels, 1, 3), dtype=self.dtype, - initializer=quantiles_initializer) - logits = self._logits_cumulative(quantiles, stop_gradient=True) - loss = math_ops.reduce_sum(abs(logits - target)) - self.add_loss(loss, inputs=None) - - # Save medians for `call`, `compress`, and `decompress`. - self._medians = quantiles[:, :, 1:2] - if not self.optimize_integer_offset: - self._medians = math_ops.round(self._medians) - - # Largest distance observed between lower tail quantile and median, - # or between median and upper tail quantile. - minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1]) - maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians) - minmax = math_ops.maximum(minima, maxima) - minmax = math_ops.ceil(minmax) - minmax = math_ops.maximum(minmax, 1) - - # Sample the density up to `minmax` around the median. - samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype) - samples += self._medians - - half = constant_op.constant(.5, dtype=self.dtype) - # We strip the sigmoid from the end here, so we can use the special rule - # below to only compute differences in the left tail of the sigmoid. - # This increases numerical stability (see explanation in `call`). - lower = self._logits_cumulative(samples - half, stop_gradient=True) - upper = self._logits_cumulative(samples + half, stop_gradient=True) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - # Add tail masses to first and last bin of pmf, as we clip values for - # compression, meaning that out-of-range values get mapped to these bins. - pmf = array_ops.concat([ - math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]), - pmf[:, 0, 1:-1], - math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]), - ], axis=-1) - self._pmf = pmf - - cdf = coder_ops.pmf_to_quantized_cdf( - pmf, precision=self.range_coder_precision) - def cdf_getter(*args, **kwargs): - del args, kwargs # ignored - return variable_scope.get_variable( - "quantized_cdf", dtype=dtypes.int32, initializer=cdf, - trainable=False, validate_shape=False, collections=()) - # Need to provide a fake shape here since add_variable insists on it. - self._quantized_cdf = self.add_variable( - "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32, - getter=cdf_getter, trainable=False) - - update_op = state_ops.assign( - self._quantized_cdf, cdf, validate_shape=False) - self.add_update(update_op, inputs=None) - - super(EntropyBottleneck, self).build(input_shape) - - def call(self, inputs, training): - """Pass a tensor through the bottleneck. - - Args: - inputs: The tensor to be passed through the bottleneck. - training: Boolean. If `True`, returns a differentiable approximation of - the inputs, and their likelihoods under the modeled probability - densities. If `False`, returns the quantized inputs and their - likelihoods under the corresponding probability mass function. These - quantities can't be used for training, as they are not differentiable, - but represent actual compression more closely. - - Returns: - values: `Tensor` with the same shape as `inputs` containing the perturbed - or quantized input values. - likelihood: `Tensor` with the same shape as `inputs` containing the - likelihood of `values` under the modeled probability distributions. - - Raises: - ValueError: if `inputs` has different `dtype` or number of channels than - a previous set of inputs the model was invoked with earlier. - """ - inputs = ops.convert_to_tensor(inputs) - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - half = constant_op.constant(.5, dtype=self.dtype) - - # Convert to (channels, 1, batch) format by commuting channels to front - # and then collapsing. - order = list(range(ndim)) - order.pop(channel_axis) - order.insert(0, channel_axis) - values = array_ops.transpose(inputs, order) - shape = array_ops.shape(values) - values = array_ops.reshape(values, (shape[0], 1, -1)) - - # Add noise or quantize. - if training: - noise = random_ops.random_uniform(array_ops.shape(values), -half, half) - values = math_ops.add_n([values, noise]) - elif self.optimize_integer_offset: - values = math_ops.round(values - self._medians) + self._medians - else: - values = math_ops.round(values) - - # Evaluate densities. - # We can use the special rule below to only compute differences in the left - # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1 - # for large x, 0 for small x. Subtracting two numbers close to 0 can be done - # with much higher precision than subtracting two numbers close to 1. - lower = self._logits_cumulative(values - half, stop_gradient=False) - upper = self._logits_cumulative(values + half, stop_gradient=False) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - sign = array_ops.stop_gradient(sign) - likelihood = abs( - math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - if self.likelihood_bound > 0: - likelihood_bound = constant_op.constant( - self.likelihood_bound, dtype=self.dtype) - # TODO(jballe): Override gradients. - likelihood = math_ops.maximum(likelihood, likelihood_bound) - - # Convert back to input tensor shape. - order = list(range(1, ndim)) - order.insert(channel_axis, 0) - values = array_ops.reshape(values, shape) - values = array_ops.transpose(values, order) - likelihood = array_ops.reshape(likelihood, shape) - likelihood = array_ops.transpose(likelihood, order) - - if not context.executing_eagerly(): - values_shape, likelihood_shape = self.compute_output_shape(inputs.shape) - values.set_shape(values_shape) - likelihood.set_shape(likelihood_shape) - - return values, likelihood - - def compress(self, inputs): - """Compress inputs and store their binary representations into strings. - - Args: - inputs: `Tensor` with values to be compressed. - - Returns: - String `Tensor` vector containing the compressed representation of each - batch element of `inputs`. - """ - with ops.name_scope(self._name_scope()): - inputs = ops.convert_to_tensor(inputs) - if not self.built: - # Check input assumptions set before layer building, e.g. input rank. - self._assert_input_compatibility(inputs) - if self.dtype is None: - self._dtype = inputs.dtype.base_dtype.name - self.build(inputs.shape) - - # Check input assumptions set after layer building, e.g. input shape. - if not context.executing_eagerly(): - self._assert_input_compatibility(inputs) - - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - # Bring inputs to the right range by centering the range on the medians. - half = constant_op.constant(.5, dtype=self.dtype) - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians - # Expand offsets to input dimensions and add to inputs. - values = inputs + offsets[slices[:-1]] - - # Clip to range and cast to integers. Because we have added .5 above, and - # all values are positive, the cast effectively implements rounding. - values = math_ops.maximum(values, half) - values = math_ops.minimum( - values, math_ops.cast(num_levels, self.dtype) - half) - values = math_ops.cast(values, dtypes.int16) - - def loop_body(tensor): - return coder_ops.range_encode( - tensor, cdf, precision=self.range_coder_precision) - strings = functional_ops.map_fn( - loop_body, values, dtype=dtypes.string, back_prop=False) - - if not context.executing_eagerly(): - strings.set_shape(inputs.shape[:1]) - - return strings - - def decompress(self, strings, shape, channels=None): - """Decompress values from their compressed string representations. - - Args: - strings: A string `Tensor` vector containing the compressed data. - shape: A `Tensor` vector of int32 type. Contains the shape of the tensor - to be decompressed, excluding the batch dimension. - channels: Integer. Specifies the number of channels statically. Needs only - be set if the layer hasn't been built yet (i.e., this is the first input - it receives). - - Returns: - The decompressed `Tensor`. Its shape will be equal to `shape` prepended - with the batch dimension from `strings`. - - Raises: - ValueError: If the length of `shape` isn't available at graph construction - time. - """ - with ops.name_scope(self._name_scope()): - strings = ops.convert_to_tensor(strings) - shape = ops.convert_to_tensor(shape) - if self.built: - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - if channels is None: - channels = self.input_spec.axes[channel_axis] - else: - if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1): - raise ValueError("`shape` must be a vector with known length.") - ndim = shape.shape[0].value + 1 - channel_axis = self._channel_axis(ndim) - input_shape = ndim * [None] - input_shape[channel_axis] = channels - self.build(input_shape) - - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - def loop_body(string): - return coder_ops.range_decode( - string, shape, cdf, precision=self.range_coder_precision) - outputs = functional_ops.map_fn( - loop_body, strings, dtype=dtypes.int16, back_prop=False) - outputs = math_ops.cast(outputs, self.dtype) - - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = math_ops.cast(num_levels // 2, self.dtype) - medians - outputs -= offsets[slices[:-1]] - - if not context.executing_eagerly(): - outputs_shape = ndim * [None] - outputs_shape[0] = strings.shape[0] - outputs_shape[channel_axis] = channels - outputs.set_shape(outputs_shape) - - return outputs - - def visualize(self): - """Multi-channel visualization of densities as images. - - Creates and returns an image summary visualizing the current probabilty - density estimates. The image contains one row for each channel. Within each - row, the pixel intensities are proportional to probability values, and each - row is centered on the median of the corresponding distribution. - - Returns: - The created image summary. - """ - with ops.name_scope(self._name_scope()): - image = self._pmf - image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True) - image = math_ops.cast(image + .5, dtypes.uint8) - image = image[None, :, :, None] - return summary.image("pmf", image, max_outputs=1) - - def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - return input_shape, input_shape diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py deleted file mode 100644 index 798b0234ebcce7df108a0da65d1305502ce0253a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests of EntropyBottleneck class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.layers import entropybottleneck - -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import gradient_descent - - -class EntropyBottleneckTest(test.TestCase): - - def test_noise(self): - # Tests that the noise added is uniform noise between -0.5 and 0.5. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck() - noisy, _ = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - values = np.linspace(-50, 50, 100)[:, None] - noisy, = sess.run([noisy], {inputs: values}) - self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49)) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - - def test_quantization(self): - # Tests that inputs are quantized to full integer values, even after - # quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6) - - def test_quantization_optimized_offset(self): - # Tests that inputs are not quantized to full integer values after quantiles - # have been updated. However, the difference between input and output should - # be between -0.5 and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - quantized) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec(self): - # Tests that inputs are compressed and decompressed correctly, and quantized - # to full integer values, even after quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6) - - def test_codec_optimized_offset(self): - # Tests that inputs are compressed and decompressed correctly, and not - # quantized to full integer values after quantiles have been updated. - # However, the difference between input and output should be between -0.5 - # and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=True) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - decoded) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec_clipping(self): - # Tests that inputs are compressed and decompressed correctly, and clipped - # to the expected range. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=40) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - expected = np.clip(np.around(values), -40, 40) - self.assertAllClose(expected, decoded, rtol=0, atol=1e-6) - - def test_channels_last(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channels in the last dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(7, 5, 3, 2)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_channels_first(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channel dimension right after the batch dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(2, 3, 5, 7)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_compress(self): - # Test compression and decompression, and produce test data for - # `test_decompress`. If you set the constant at the end to `True`, this test - # will fail and the log will contain the new test data. - inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), init_scale=2) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5 - bitstrings, quantized_cdf, decoded = sess.run( - [bitstrings, layer._quantized_cdf, decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - # Set this constant to `True` to log new test data for `test_decompress`. - if False: # pylint:disable=using-constant-test - assert False, (bitstrings, quantized_cdf, decoded) - - # Data generated by `test_compress`. - # pylint:disable=g-inconsistent-quotes,bad-whitespace - bitstrings = np.array([ - b'\x1e\xbag}\xc2\xdaN\x8b\xbd.', - b'\x8dF\xf0%\x1cv\xccllW' - ], dtype=object) - - quantized_cdf = np.array([ - [ 0, 15636, 22324, 30145, 38278, 65536], - [ 0, 19482, 26927, 35052, 42904, 65535], - [ 0, 21093, 28769, 36919, 44578, 65536] - ], dtype=np.int32) - - expected = np.array([ - [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.], - [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.], - [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]], - [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.], - [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.], - [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]] - ], dtype=np.float32) - # pylint:enable=g-inconsistent-quotes,bad-whitespace - - def test_decompress(self): - # Test that decompression of values compressed with a previous version - # works, i.e. that the file format doesn't change across revisions. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32) - quantized_cdf = array_ops.placeholder(dtypes.int32) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), dtype=dtypes.float32) - layer.build(self.expected.shape) - layer._quantized_cdf = quantized_cdf - decoded = layer.decompress(bitstrings, input_shape[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - decoded, = sess.run([decoded], { - bitstrings: self.bitstrings, input_shape: self.expected.shape, - quantized_cdf: self.quantized_cdf}) - self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6) - - def test_build_decompress(self): - # Test that layer can be built when `decompress` is the first call to it. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32, shape=[3]) - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.decompress(bitstrings, input_shape[1:], channels=5) - self.assertTrue(layer.built) - - def test_pmf_normalization(self): - # Test that probability mass functions are normalized correctly. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - pmf, = sess.run([layer._pmf]) - self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6) - - def test_visualize(self): - # Test that summary op can be constructed. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - summary = layer.visualize() - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run([summary]) - - def test_normalization(self): - # Test that densities are normalized correctly. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(filters=(2,)) - _, likelihood = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - x = np.repeat(np.arange(-200, 201), 1000)[:, None] - likelihood, = sess.run([likelihood], {inputs: x}) - self.assertEqual(x.shape, likelihood.shape) - integral = np.sum(likelihood) * .001 - self.assertAllClose(1, integral, rtol=0, atol=1e-4) - - def test_entropy_estimates(self): - # Test that entropy estimates match actual range coding. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - filters=(2, 3), data_format="channels_last") - _, likelihood = layer(inputs, training=True) - diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - _, likelihood = layer(inputs, training=False) - disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - bitstrings = layer.compress(inputs) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - diff_entropy, disc_entropy, bitstrings = sess.run( - [diff_entropy, disc_entropy, bitstrings], - {inputs: np.random.normal(size=(1, 10000, 1))}) - codelength = 8 * sum(len(bitstring) for bitstring in bitstrings) - self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0) - self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0) - self.assertGreater(codelength, disc_entropy) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 3791dae8d7f6b03bc1115bca97811dfc4775c45b..ff846b191a34e3f3b4aa35671ca22b96b963db80 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -150,7 +150,7 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): "matrix must be two dimensional (instead is %d-dimensional)" % matrix_shape.ndims) if matrix_shape[0] != matrix_shape[1]: - raise ValueError("matrix must be be square (instead has shape (%d,%d))" % + raise ValueError("matrix must be square (instead has shape (%d,%d))" % (matrix_shape[0], matrix_shape[1])) dimension = matrix_shape[0].value if dimension is None: diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 5931c8a27996534cca80797e8b840559c124297c..6c9ab6aeb87fd39b22ab4f28d69b432b15899a13 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -219,8 +219,10 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): op_def) #Use Graph's hidden methods to add the op to_graph._record_op_seen_by_control_dependencies(new_op) - for device_function in reversed(to_graph._device_function_stack): + # pylint: disable=protected-access + for device_function in to_graph._device_functions_outer_to_inner: new_op._set_device(device_function(new_op)) + # pylint: enable=protected-access return new_op diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index f56a973f6f80b81697e9f58578e60a2efb90154e..8cfe14205927bf7763cf36fa31012ab10fce995c 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -158,7 +158,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], @@ -291,7 +291,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 8a7ff61bc8391efe453ee37019c23bd6ccbdf066..2a91dcb63a80016e62d10d1310ca57e3e54434c5 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -548,7 +548,9 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] # Sequence length is not allowed to be less than zero. - sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) + sequence_length_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_length.dtype), + sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, inputs=inputs, diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 566cbb246a104d1e6cfc284d220ca8386b8897e1..2e249f5c14ab111ae412ff3288acc25de8d7aa11 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -37,6 +37,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -58,6 +59,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -68,6 +70,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -78,6 +81,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 6edc61b2c2f5fae6ad87db8afea2ef3829d76c95..32f03ca68364e40c6fd6769f05d0566f50119240 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -791,16 +791,17 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { const Tensor* tensor_shard_num; - OP_REQUIRES_OK(ctx, ctx->input("shard_num", &tensor_shard_num)); + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done); int32 shard_num = tensor_shard_num->scalar()(); const Tensor* tensor_incarnation_id; - OP_REQUIRES_OK(ctx, ctx->input("incarnation_id", &tensor_incarnation_id)); + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done); int64 incarnation_id = tensor_incarnation_id->scalar()(); MultiDeviceIterator* iterator; - OP_REQUIRES_OK(ctx, - LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); thread_pool_->Schedule(std::bind( [ctx, iterator, shard_num, incarnation_id](DoneCallback done) { std::vector components; diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 036dc795bb4876f3f96a9e570a9325f1532113ed..ea92191f3e20218d58bb2d0ba1ce7c1120361d45 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -192,27 +192,54 @@ py_test( deps = [ "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) +py_test( + name = "map_defun_op_test", + size = "small", + srcs = ["map_defun_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:map_defun", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "optimize_dataset_op_test", size = "small", srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:math_ops", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) @@ -234,7 +261,12 @@ cuda_py_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", ], - tags = ["no_windows_gpu"], + tags = [ + "manual", + "no_oss", + "no_windows_gpu", + "notap", + ], ) py_test( @@ -424,8 +456,8 @@ py_test( tags = ["no_pip"], deps = [ ":reader_dataset_ops_test_base", + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -435,6 +467,16 @@ py_test( ], ) +py_library( + name = "stats_dataset_test_base", + srcs = ["stats_dataset_test_base.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "threadpool_dataset_ops_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 30a993b1f7056b9726f524b2279131339c80c5eb..77148aceec7fa90f927a9c009671c2939460877b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util @@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase): def _read_vars(self, model_dir): """Returns (global_step, latest_feature).""" with ops.Graph().as_default() as g: - ckpt_path = saver_lib.latest_checkpoint(model_dir) + ckpt_path = checkpoint_management.latest_checkpoint(model_dir) meta_filename = ckpt_path + '.meta' saver_lib.import_meta_graph(meta_filename) saver = saver_lib.Saver() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index b7025f3802c0c280981df20c86747e49fdf2274f..009e21a34c8df86af6abbb7599dbcfa23ddf90a7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import optimization from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops @@ -79,18 +80,21 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) def testReadFileIgnoreError(self): + def write_string_to_file(value, filename): with open(filename, "w") as f: f.write(value) - filenames = [os.path.join(self.get_temp_dir(), "file_%d.txt" % i) - for i in range(5)] + + filenames = [ + os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5) + ] for filename in filenames: write_string_to_file(filename, filename) dataset = ( dataset_ops.Dataset.from_tensor_slices(filenames).map( - io_ops.read_file, num_parallel_calls=2).prefetch(2).apply( - error_ops.ignore_errors())) + io_ops.read_file, + num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -264,5 +268,91 @@ class MapDatasetBenchmark(test.Benchmark): benchmark("Transformation parallelism evaluation", par_num_calls_series) benchmark("Threadpool size evaluation", par_inter_op_series) + # This benchmark compares the performance of pipeline with multiple chained + # maps with and without map fusion. + def benchmarkChainOfMaps(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkChainOfMaps(chain_length, False) + self._benchmarkChainOfMaps(chain_length, True) + + def _benchmarkChainOfMaps(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x) + if optimize_dataset: + dataset = dataset.apply(optimization.optimize(["map_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map dataset {} chain length: {} Median wall time: {}".format( + opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +class MapAndFilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # map + filter with and without map fusion. + def benchmarkMapAndFilter(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkMapAndFilter(chain_length, False) + self._benchmarkMapAndFilter(chain_length, True) + + def _benchmarkMapAndFilter(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x + 5).filter( + lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply( + optimization.optimize(["map_and_filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map and filter dataset {} chain length: {} Median wall time: {}". + format(opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a711325daed12f45e4e533f18ee81adc7dec93be --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MapDefunOp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapDefunTest(test.TestCase): + + def testMapDefun_Simple(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0] + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_MismatchedTypes(self): + + @function.Defun(dtypes.int32) + def fn(x): + return math_ops.cast(x, dtypes.float64) + + with self.test_session(): + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefun_MultipleOutputs(self): + + @function.Defun(dtypes.int32) + def fn(x): + return (x, math_ops.cast(x * 2 + 3, dtypes.float64)) + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], + [(2,), (2,)]) + expected = [elems, elems * 2 + 3] + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_ShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] + self.assertEqual(result.get_shape(), (3, 2)) + + def testMapDefun_PartialShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + elems = array_ops.placeholder(dtypes.int64, (None, 2)) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) + self.assertEqual(result[0].get_shape().as_list(), [None, 2]) + + def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def fn(x, y): + return x, y + + elems1 = array_ops.placeholder(dtypes.int32) + elems2 = array_ops.placeholder(dtypes.int32) + result = map_defun.map_defun(fn, [elems1, elems2], + [dtypes.int32, dtypes.int32], [(), ()]) + with self.test_session() as sess: + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "All inputs must have the same dimension 0."): + sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) + + def testMapDefun_RaisesDefunError(self): + + @function.Defun(dtypes.int32) + def fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + elems = constant_op.constant([0, 0, 0, 37, 0]) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()]) + with self.test_session(): + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index cfef40e1923607406d6587466dc5f533499220eb..ae147b4fa79c5fc8e63e1860f45036709ecc9777 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -17,13 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import optimization +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase): +class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): def testAssertSuffix(self): dataset = dataset_ops.Dataset.from_tensors(0).apply( @@ -44,8 +51,7 @@ class OptimizeDatasetTest(test.TestCase): with self.assertRaisesRegexp( errors.InvalidArgumentError, "Asserted Whoops transformation at offset 0 but encountered " - "Map transformation instead." - ): + "Map transformation instead."): sess.run(get_next) def testAssertSuffixShort(self): @@ -110,6 +116,166 @@ class OptimizeDatasetTest(test.TestCase): "Function .* is not defined."): sess.run(get_next) + @staticmethod + def map_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + + def increment_and_square(x): + y = x + 1 + return y * y + + functions = [identity, increment, increment_and_square] + tests = [] + for i, fun1 in enumerate(functions): + for j, fun2 in enumerate(functions): + tests.append(( + "test_{}_{}".format(i, j), + [fun1, fun2], + )) + for k, fun3 in enumerate(functions): + tests.append(( + "test_{}_{}_{}".format(i, j, k), + [fun1, fun2, fun3], + )) + + swap = lambda x, n: (n, x) + tests.append(( + "swap1", + [lambda x: (x, 42), swap], + )) + tests.append(( + "swap2", + [lambda x: (x, 42), swap, swap], + )) + return tuple(tests) + + @parameterized.named_parameters(*map_functions.__func__()) + def testMapFusion(self, functions): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Prefetch"])) + for function in functions: + dataset = dataset.map(function) + + dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + r = x + for function in functions: + if isinstance(r, tuple): + r = function(*r) # Pass tuple as multiple arguments. + else: + r = function(r) + self.assertAllEqual(r, result) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @staticmethod + def map_and_filter_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + minus_five = lambda x: x - 5 + + def increment_and_square(x): + y = x + 1 + return y * y + + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + is_odd = lambda x: math_ops.equal(x % 2, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + functions = [identity, increment, minus_five, increment_and_square] + filters = [take_all, is_zero, is_odd, greater] + tests = [] + + for x, fun in enumerate(functions): + for y, predicate in enumerate(filters): + tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + + # Multi output + tests.append(("multiOne", lambda x: (x, x), + lambda x, y: constant_op.constant(True))) + tests.append( + ("multiTwo", lambda x: (x, 2), + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) + return tuple(tests) + + @parameterized.named_parameters(*map_and_filter_functions.__func__()) + def testMapFilterFusion(self, function, predicate): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", + "FilterByLastComponent"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + self._testMapAndFilter(dataset, function, predicate) + + def _testMapAndFilter(self, dataset, function, predicate): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(10): + r = function(x) + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if sess.run(b): + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testAdditionalInputs(self): + a = constant_op.constant(3, dtype=dtypes.int64) + b = constant_op.constant(4, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + function = lambda x: x * x + + def predicate(y): + return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) + + # We are currently not supporting functions with additional inputs. + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Filter"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + + self._testMapAndFilter(dataset, function, predicate) + + +class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testLatencyStatsOptimization(self): + + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.from_tensors(1).apply( + optimization.assert_next( + ["LatencyStats", "Map", "LatencyStats", "Prefetch", + "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( + optimization.optimize(["latency_all_edges"])).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertEqual(1 * 1, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, + "record_latency_TensorDataset/_1", 1) + self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", + 1) + self._assertSummaryHasCount(summary_str, + "record_latency_PrefetchDataset/_6", 1) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 2da6131e8e60ca53723da7f66a7ee52151640129..d66305d7326f78d1e414b6076c1ca6a029baa2f7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -907,6 +907,42 @@ class CopyToDeviceTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testIteratorGetNextAsOptionalOnGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(3) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_elem = iterator_ops.get_next_as_optional(iterator) + elem_has_value_t = next_elem.has_value() + elem_value_t = next_elem.get_value() + + with self.test_session() as sess: + # Before initializing the iterator, evaluating the optional fails with + # a FailedPreconditionError. + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_has_value_t) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_value_t) + + # For each element of the dataset, assert that the optional evaluates to + # the expected value. + sess.run(iterator.initializer) + for i in range(3): + elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) + self.assertTrue(elem_has_value) + self.assertEqual(i, elem_value) + + # After exhausting the iterator, `next_elem.has_value()` will evaluate to + # false, and attempting to get the value will fail. + for _ in range(2): + self.assertFalse(sess.run(elem_has_value_t)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(elem_value_t) + class MultiDeviceIteratorTest(test.TestCase): diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 851a33dfc849a2d935887def44734aace5dcaf7f..15b342d30f85a05b3827998565ba5f84021ac885 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -173,15 +173,23 @@ class ReadBatchFeaturesTest( for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self.make_batch_feature( + outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, drop_final_batch=True).make_one_shot_iterator().get_next() - for _, tensor in self.outputs.items(): + for _, tensor in outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) + def testIndefiniteRepeatShapeInference(self): + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], num_epochs=None, batch_size=32) + for shape, clazz in zip(nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_classes)): + if issubclass(clazz, ops.Tensor): + self.assertEqual(32, shape[0]) + class MakeCsvDatasetTest(test.TestCase): @@ -795,6 +803,16 @@ class MakeCsvDatasetTest(test.TestCase): all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) self.assertFalse(all_equal) + def testIndefiniteRepeatShapeInference(self): + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) + dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + class MakeTFRecordDatasetTest( reader_dataset_ops_test_base.TFRecordDatasetTestBase): @@ -1002,5 +1020,12 @@ class MakeTFRecordDatasetTest( self._shuffle_test(batch_size, num_epochs, num_parallel_reads, seed=21345) + def testIndefiniteRepeatShapeInference(self): + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, num_epochs=None, batch_size=32) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 3c3f23f9a984c702abfdacf11bef0e5d4066782f..7b9ea191a4524891d1b589e1e228e29241fda7f8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -56,6 +56,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py index a0a1100893c7384b0e2bd9fcfdaa8d3698b95d28..1b6059ccbcc81937696e1b0ebb269f213adbb976 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -19,6 +19,8 @@ from __future__ import print_function import os +from absl.testing import parameterized + from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -26,7 +28,8 @@ from tensorflow.python.platform import test class CacheDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): def setUp(self): self.range_size = 10 @@ -34,88 +37,123 @@ class CacheDatasetSerializationTest( self.num_outputs = self.range_size * self.num_repeats self.cache_file_prefix = 'test' - def ds_fn(self): - return dataset_ops.Dataset.range(self.range_size).cache( - os.path.join(self.get_temp_dir(), - self.cache_file_prefix)).repeat(self.num_repeats) + def make_dataset_fn(self, is_memory): + if is_memory: + filename = '' + else: + filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix) + + def ds_fn(): + return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat( + self.num_repeats) + + return ds_fn def expected_outputs(self): return list(range(self.range_size)) * self.num_repeats - def testCheckpointBeforeOneEpoch(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Generate 5 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointBeforeOneEpochThenRunFewSteps(self): - # Generate 8 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 8 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 8, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, range(8)) - # Restoring from checkpoint and running GetNext should return a - # `AlreadExistsError` now because the lockfile already exists. - with self.assertRaises(errors.AlreadyExistsError): - self.gen_outputs( - self.ds_fn, [], - self.num_outputs - 5, - ckpt_saved=True, - verify_exhausted=False) + if is_memory: + outputs = outputs[:5] + outputs.extend( + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Restoring from checkpoint and running GetNext should return + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testCheckpointAfterOneEpoch(self): # Generate 15 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointAfterOneEpochThenRunFewSteps(self): - # Generate 18 entries from iterator but save checkpoint after producing - # 15. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 18 entries from iterator but save checkpoint after producing 15. outputs = self.gen_outputs( - self.ds_fn, [15], - 18, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) outputs = list(range(10)) + list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): - # Generate 13 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 13 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 13, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) # Since we ran for more than one epoch, the cache was completely written. @@ -124,65 +162,90 @@ class CacheDatasetSerializationTest( # been completely written. outputs = list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Checkpoint before get_next is called even once. - outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False) self.assertSequenceEqual(outputs, []) outputs = self.gen_outputs( - self.ds_fn, [], - self.num_outputs, - ckpt_saved=True, - verify_exhausted=False) + ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedMidwayWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedMidwayWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint, then produce no elements and checkpoint. outputs.extend( - self.gen_outputs( - self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce rest of the elements. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testUnusedCheckpointError(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testUnusedCheckpointError(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and save ckpt. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) - # Since the complete cache has not been written, a new iterator which does - # not restore the checkpoint will throw an error since there is a partial - # cache shard. - with self.assertRaises(errors.AlreadyExistsError): + if is_memory: outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testIgnoreCheckpointIfCacheWritten(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testIgnoreCheckpointIfCacheWritten(self): # Produce 15 elements and save ckpt. This will write the complete cache. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Build the iterator again but do not restore from ckpt. Since the cache # has already been written we should be able to use it. outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 393f08850b1865180a8b94e9209b2445b54c8b69..3ed4dfb7295ca77c78ce5318bf31e16a354e16a8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase): return os.path.join(self.get_temp_dir(), "iterator") def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) + return checkpoint_management.latest_checkpoint(self.get_temp_dir()) def _save(self, sess, saver): saver.save(sess, self._ckpt_path()) diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index b4945685c1d1062bf416b73f1541f351adf45604..a41d21f8c14ed6bec7626599a5aa7f365765ce8b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import stats_ops -from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -29,28 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTestBase(test.TestCase): - - def _assertSummaryHasCount(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.num) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - def _assertSummaryHasSum(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.sum) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - -class StatsDatasetTest(StatsDatasetTestBase): +class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() @@ -197,7 +176,7 @@ class StatsDatasetTest(StatsDatasetTestBase): class FeatureStatsDatasetTest( - StatsDatasetTestBase, + stats_dataset_test_base.StatsDatasetTestBase, reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testFeaturesStats(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a13acf8f0ac6690cad8847873768562da795496 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing the input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.platform import test + + +class StatsDatasetTestBase(test.TestCase): + """Base class for testing statistics gathered in `StatsAggregator`.""" + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 1ad021ea037add48afee5bdfda9eea18485eca5d..ad9378dfb9d938c826f994da9bbb89101cfbd872 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -210,6 +210,17 @@ py_library( ], ) +py_library( + name = "map_defun", + srcs = ["map_defun.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -370,6 +381,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":map_defun", ":optimization", ":prefetching_ops", ":readers", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index a4914f4cde71925af477636c91d98b54ce0cce0e..4835c4e5bd9efded57f19d6a382b145ae1b05e93 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -31,7 +31,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -439,48 +438,6 @@ def unbatch(): return _apply_fn -def _filter_irregular_batches(batch_size): - """Transformation that filters out batches that are not of size batch_size.""" - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - - flattened = _RestructuredDataset( - dataset, - tuple(nest.flatten(dataset.output_types)), - output_classes=tuple(nest.flatten(dataset.output_classes))) - - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] - - return math_ops.equal(first_component_batch_size, tensor_batch_size) - - filtered = flattened.filter(_predicate) - - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) - - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) - - known_shapes = nest.map_structure(_set_first_dimension, - dataset.output_shapes) - return _RestructuredDataset( - filtered, - dataset.output_types, - known_shapes, - output_classes=dataset.output_classes) - - return _apply_fn - - @deprecation.deprecated( None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): @@ -515,10 +472,7 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time - # after 6/30/2018. - batched = dataset.batch(batch_size) - return _filter_irregular_batches(batch_size)(batched) + return dataset.batch(batch_size, drop_remainder=True) return _apply_fn @@ -553,11 +507,9 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` - # any time after 6/30/2018. - batched = dataset.padded_batch( - batch_size, padded_shapes=padded_shapes, padding_values=padding_values) - return _filter_irregular_batches(batch_size)(batched) + return dataset.padded_batch( + batch_size, padded_shapes=padded_shapes, padding_values=padding_values, + drop_remainder=True) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 0d71be66018eeebe60de9deff24ceb6854d209d9..d2c1d0d3620f94f867395a8e2fff0d77a6dc0718 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook @@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # Check if there is an existing checkpoint. If so, restore from it. # pylint: disable=protected-access - latest_checkpoint_path = saver_lib.latest_checkpoint( + latest_checkpoint_path = checkpoint_management.latest_checkpoint( self._checkpoint_saver_hook._checkpoint_dir, latest_filename=self._latest_filename) if latest_checkpoint_path: diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py new file mode 100644 index 0000000000000000000000000000000000000000..54d5cd6da068fa5471b7beafcc66d76b5972e7d5 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/map_defun.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +def map_defun(fn, elems, output_dtypes, output_shapes): + """Map a function on the list of tensors unpacked from `elems` on dimension 0. + + Args: + fn: A function (`function.Defun`) that takes a list of tensors and returns + another list of tensors. The output list has the same types as + output_dtypes. The elements of the output list have the same dimension 0 + as `elems`, and the remaining dimensions correspond to those of + `fn_output_shapes`. + elems: A list of tensors. + output_dtypes: A list of dtypes corresponding to the output types of the + function. + output_shapes: A list of `TensorShape`s corresponding to the output + shapes from each invocation of the function on slices of inputs. + + Raises: + ValueError: if any of the inputs are malformed. + + Returns: + A list of `Tensor` objects with the same types as `output_dtypes`. + """ + if not isinstance(elems, list): + raise ValueError("`elems` must be a list of tensors.") + if not isinstance(output_dtypes, list): + raise ValueError("`output_dtypes` must be a list of tensors.") + if not isinstance(output_shapes, list): + raise ValueError("`output_shapes` must be a list of tensors.") + + elems = [ops.convert_to_tensor(e) for e in elems] + output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] + if not all(s.is_fully_defined() for s in output_shapes): + raise ValueError("All fn output shapes must be fully defined.") + return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 0edd7c9fe974784f199c272a649b302e72d8c218..0243c72c70716d31e3ab8b6a3da2270ee0bbc91b 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -633,6 +633,15 @@ class MultiDeviceIterator(object): devices, prefetch_buffer_size=1, source_device="/cpu:0"): + """Constructs a MultiDeviceIterator. + + Args: + dataset: The input dataset to be iterated over. + devices: The list of devices to fetch data to. + prefetch_buffer_size: if > 1, then we setup a buffer on each device + to prefetch into. + source_device: The host device to place the `dataset` on. + """ self._dataset = dataset self._devices = devices self._source_device = source_device @@ -673,7 +682,8 @@ class MultiDeviceIterator(object): i, self._multi_device_iterator_resource, self._incarnation_id, self._source_device_tensor, device, self._dataset.output_shapes, self._dataset.output_types, self._dataset.output_classes) - ds = ds.prefetch(prefetch_buffer_size) + if prefetch_buffer_size > 0: + ds = ds.prefetch(prefetch_buffer_size) with ops.device(device): self._device_iterators.append(ds.make_initializable_iterator()) i += 1 diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index f018dd02e6ae9de69c7364677e1756d1e11bf484..14d69f8d5b29d43649185e689c7a8e6604361bca 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -286,11 +286,14 @@ def make_tf_record_dataset( dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + drop_final_batch = drop_final_batch or num_epochs is None + if parser_fn is None: - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) else: # TODO(josh11b): if num_parallel_parser_calls is None, use some function # of num cores instead of map_and_batch's default behavior of one batch. @@ -493,8 +496,13 @@ def make_csv_dataset( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) # Apply batch before map for perf, because map has high overhead relative - # to the size of the computation in each map - dataset = dataset.batch(batch_size=batch_size) + # to the size of the computation in each map. + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch(batch_size=batch_size, + drop_remainder=num_epochs is None) dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) @@ -772,10 +780,12 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.apply(stats_ops.feature_stats("record_stats")) - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch( + batch_size, drop_remainder=drop_final_batch or num_epochs is None) # Parse `Example` tensors to a dictionary of `Feature` tensors. dataset = dataset.map( diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index 1126f76f5854932bcb6a9550c100768069bbd1cc..d3628d480d31017f835b39f750df40cafa2cc0db 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -25,10 +25,13 @@ py_library( srcs = ["__init__.py"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", + "//tensorflow/contrib/distribute/python:multi_worker_strategy", "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:parameter_server_strategy", "//tensorflow/contrib/distribute/python:step_fn", "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 2e2c3be853cc5503c86121c142394d49e5037405..9123ca749b68a1d0066313c77914fa3fb8006a9e 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -19,10 +19,13 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.training.distribute import * @@ -32,11 +35,14 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AllReduceCrossTowerOps', + 'CollectiveAllReduceStrategy', 'CrossTowerOps', 'DistributionStrategy', 'MirroredStrategy', + 'MultiWorkerMirroredStrategy', 'Monitor', 'OneDeviceStrategy', + 'ParameterServerStrategy', 'ReductionToOneDeviceCrossTowerOps', 'Step', 'StandardInputStep', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index f5d7e24ae2e3aa76efc50f4da93411f66edea651..3159dd154aee2c00b24d36f841584e2582c99b19 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -100,6 +100,23 @@ py_library( ], ) +py_library( + name = "parameter_server_strategy", + srcs = ["parameter_server_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + py_library( name = "one_device_strategy", srcs = ["one_device_strategy.py"], @@ -116,6 +133,24 @@ py_library( ], ) +py_library( + name = "collective_all_reduce_strategy", + srcs = ["collective_all_reduce_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":cross_tower_utils", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + ], +) + py_library( name = "strategy_test_lib", testonly = 1, @@ -152,6 +187,7 @@ py_library( ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", @@ -207,6 +243,35 @@ py_test( ], ) +py_test( + name = "parameter_server_strategy_test", + srcs = ["parameter_server_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + ":multi_worker_test_base", + ":parameter_server_strategy", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:layers", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:run_config", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], @@ -247,11 +312,11 @@ py_library( ], deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", "//tensorflow/python:distributed_framework_test_lib", - "//tensorflow/python:platform", "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:run_config", + "//third_party/py/numpy", ], ) @@ -272,8 +337,7 @@ py_library( deps = [ ":one_device_strategy", ":values", - "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", @@ -281,6 +345,37 @@ py_library( ], ) +py_test( + name = "collective_all_reduce_strategy_test", + srcs = ["collective_all_reduce_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":cross_tower_utils", + ":multi_worker_test_base", + ":strategy_test_lib", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:run_config", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "minimize_loss_test_lib", testonly = 1, @@ -451,8 +546,11 @@ py_library( "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", ], ) @@ -487,7 +585,9 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -495,6 +595,7 @@ py_library( cuda_py_test( name = "cross_tower_ops_test", + size = "large", srcs = ["cross_tower_ops_test.py"], additional_deps = [ ":combinations", @@ -509,7 +610,6 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index fe3df9cbb95308251581005fdb858cccd5d19a1d..bcb977f64073b1d15ef5c872eb0d6b09d5307b54 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -49,17 +49,23 @@ class CheckpointUtilsWithDistributionStrategyTest( def testInitFromCheckpoint(self, distribution, in_tower_mode): checkpoint_dir = self.get_temp_dir() with self.test_session() as session: - v1_value, _, _, _ = checkpoint_utils_test._create_checkpoints( + v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( session, checkpoint_dir) def init_and_verify(g): v1 = variable_scope.get_variable("new_var1", [1, 10]) + v2 = variable_scope.get_variable( + "new_var2", [10, 10], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) checkpoint_utils.init_from_checkpoint(checkpoint_dir, { "var1": "new_var1", + "var2": "new_var2" }) with self.test_session(graph=g) as session: session.run(variables.global_variables_initializer()) self.assertAllEqual(v1_value, self.evaluate(v1)) + self.assertAllEqual(v2_value, self.evaluate(v2)) with ops.Graph().as_default() as g, distribution.scope(): if in_tower_mode: diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9afcaecf78844b011a9dbc30bb95fa3bfeda8470 --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -0,0 +1,205 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class CollectiveAllReduceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops +from tensorflow.python.training import server_lib + + +# TODO(yuefengz): move this function to a common util file. +def _normalize_cluster_spec(cluster_spec): + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): shard the dataset. +# TODO(yuefengz): support in-graph replication. +# TODO(yuefengz): it only works with a cluster without a chief node, maybe +# support chief node? +class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): + """Distribution strategy that uses collective ops for all-reduce. + + It is similar to the MirroredStrategy but it uses collective ops for + reduction. It currently only works for between-graph replication and its + reduction will reduce across all workers. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type="worker", + task_id=0): + """Initializes the object. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type, such as "worker". + task_id: the current task id. + + Raises: + ValueError: if `task_type` is not in the `cluster_spec`. + """ + self._num_gpus_per_worker = num_gpus_per_worker + self._initialize(cluster_spec, task_type, task_id) + + def _initialize(self, cluster_spec, task_type, task_id): + if task_type not in ["chief", "worker"]: + raise ValueError( + "Unrecognized task_type: %r, valid task types are: \"chief\", " + "\"worker\"." % task_type) + if cluster_spec: + self._cluster_spec = _normalize_cluster_spec(cluster_spec) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + num_workers = len(self._cluster_spec.as_dict().get(task_type, [])) + if "chief" in self._cluster_spec.as_dict(): + num_workers += 1 + if not num_workers: + raise ValueError("`task_type` shoud be in `cluster_spec`.") + + # TODO(yuefengz): create a utility to infer chief. + if "chief" in self._cluster_spec.as_dict() and task_type == "chief": + assert task_id == 0 + self._is_chief = True + else: + assert task_type == "worker" + self._is_chief = task_id == 0 + else: + self._cluster_spec = None + self._is_chief = True + worker_device = "" + num_workers = 1 + self._num_workers = num_workers + + if self._num_gpus_per_worker: + local_devices = [ + "%s/device:GPU:%d" % (worker_device, i) + for i in range(self._num_gpus_per_worker) + ] + else: + local_devices = [worker_device] + + self._collective_keys = cross_tower_utils.CollectiveKeys() + super(CollectiveAllReduceStrategy, self).__init__( + devices=local_devices, + cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( + num_workers=num_workers, + num_gpus_per_worker=self._num_gpus_per_worker, + collective_keys=self._collective_keys)) + + # Add a default device so that ops without specified devices will not end up + # on other workers. + if cluster_spec: + self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id) + + def _create_variable(self, next_creator, *args, **kwargs): + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + group_size = len(devices) * self._num_workers + group_key = self._collective_keys.get_group_key(self._devices) + + def _real_mirrored_creator(devices, *args, **kwargs): + """Creates one MirroredVariable on the current worker.""" + index = {} + collective_instance_key = self._collective_keys.get_instance_key( + key_id=kwargs["name"]) + if "initial_value" not in kwargs: + raise ValueError("Initial value must be specified.") + initial_value = kwargs["initial_value"] + if callable(initial_value): + initial_value_fn = initial_value + else: + initial_value_fn = lambda: initial_value + + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) + + # The initial value fn makes sure variables all initialized to + # same values. The first device of the chief worker will send their + # variable values to other devices and other workers. + def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring + with ops.device(device): + initial_value = initial_value_fn() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor(initial_value) + + if self._is_chief and index == 0: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv( + initial_value.shape, initial_value.dtype, group_size, + group_key, collective_instance_key) + + kwargs["initial_value"] = _overridden_initial_value_fn + + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + + assert not isinstance(v, values.DistributedVariable) + index[d] = v + return index + + # pylint: disable=protected-access + return mirrored_strategy._create_mirrored_variable( + devices, _real_mirrored_creator, *args, **kwargs) + + def configure(self, session_config=None): + # Use TF_CONFIG to get the cluster spec and the current job. + if not self._cluster_spec: + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = 0 + + if cluster_spec: + self._initialize(cluster_spec, task_type, task_id) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e54e3b7d7156e87731e6f79aa66262d127232c --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CollectiveAllReduceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class DistributedCollectiveAllReduceStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 0 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ] + } + + def setUp(self): + self._run_options = config_pb2.RunOptions() + self._run_options.experimental.collective_graph_key = 6 + + self._sess_config = config_pb2.ConfigProto() + self._sess_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + # We use a different key_base for each test so that collective keys won't be + # reused. + # TODO(yuefengz, tucker): enable it to reuse collective keys in different + # tests. + DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000 + super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + + def _get_test_object(self, task_type, task_id, num_gpus=0): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus, + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_start=num_gpus * 100 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base) + distribution._collective_keys = collective_keys + distribution._cross_tower_ops._collective_keys = collective_keys + return distribution, self._workers[task_id].target + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_object(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + + for i in range(10): + b, a = sess.run((before_out, after_out), options=self._run_options) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + def _test_variable_initialization(self, task_type, task_id, num_gpus): + distribution, master_target = self._get_test_object(task_type, task_id, + num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + distribution.scope(): + + def model_fn(): + x = variable_scope.get_variable( + 'x', + shape=(2, 3), + initializer=init_ops.random_uniform_initializer( + 1.0, 10.0, dtype=dtypes.float32)) + return array_ops.identity(x) + + x = distribution.call_for_each_tower(model_fn) + reduced_x = distribution.unwrap( + distribution.reduce( + variable_scope.VariableAggregation.MEAN, x, + destinations='/cpu:0'))[0] + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + x_value, reduced_x_value = sess.run( + [x, reduced_x], options=self._run_options) + self.assertTrue(np.array_equal(x_value, reduced_x_value)) + return np.array_equal(x_value, reduced_x_value) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testVariableInitialization(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_variable_initialization, + self._cluster_spec, + num_gpus=num_gpus) + + +class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testMinimizeLossGraph(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. + if context.num_gpus() < num_gpus: + return + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + self._test_minimize_loss_graph(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 9a8ea4aa48b8cf4c5906f18d8bddacc224e0b644..120349481ff11dd47d88154c72b37f81b2a1074f 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,6 +46,7 @@ import unittest from absl.testing import parameterized import six +from tensorflow.contrib.cluster_resolver import TPUClusterResolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import multi_worker_strategy from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib @@ -144,7 +145,7 @@ def _augment_with_special_arguments(test_method): """A wrapped test method that treats some arguments in a special way.""" mode = kwargs.pop("mode", "graph") - distribution = kwargs.pop("distribution", None) + distribution = kwargs.get("distribution", None) required_tpu = kwargs.pop("required_tpu", False) required_gpus = kwargs.pop("required_gpus", None) @@ -153,7 +154,6 @@ def _augment_with_special_arguments(test_method): "Do not use `required_gpus` and `distribution` together.") assert required_tpu is False, ( "Do not use `required_tpu` and `distribution` together.") - kwargs["distribution"] = distribution.strategy required_gpus = distribution.required_gpus required_tpu = distribution.required_tpu @@ -189,9 +189,13 @@ def _augment_with_special_arguments(test_method): if mode == "eager": with ops.Graph().as_default(), context.eager_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) elif mode == "graph": with ops.Graph().as_default(), context.graph_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) else: raise ValueError( @@ -321,7 +325,9 @@ default_strategy = NamedDistribution( one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) -tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) +tpu_strategy = NamedDistribution( + "TPU", lambda: tpu_lib.TPUStrategy(TPUClusterResolver("")), + required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index b0baf0dad1d55eafac5338d1eb43465927e428a1..9b5534393edf32145dd5328407d365ac0676879b 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -28,18 +28,37 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_util +def check_destinations(destinations): + """Checks whether `destinations` is not None and not empty. + + Args: + destinations: a DistributedValues, Variable, string or a list of strings. + + Returns: + Boolean indicating whether `destinations` is not None and not empty. + """ + # Calling bool() on a ResourceVariable is not allowed. + if isinstance(destinations, resource_variable_ops.ResourceVariable): + return bool(destinations.device) + return bool(destinations) + + def validate_destinations(destinations): - if not isinstance(destinations, - (value_lib.DistributedValues, six.string_types, list)): + if not isinstance( + destinations, + (value_lib.DistributedValues, resource_variable_ops.ResourceVariable, + six.string_types, list)): raise ValueError("destinations must be one of a `DistributedValues` object," - " a device string, a list of device strings or None") + " a tf.Variable object, a device string, a list of device " + "strings or None") - if not destinations: + if not check_destinations(destinations): raise ValueError("destinations can not be empty") @@ -59,6 +78,8 @@ def _validate_value_destination_pairs(value_destination_pairs): def get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) + elif isinstance(destinations, resource_variable_ops.ResourceVariable): + return [destinations.device] elif isinstance(destinations, six.string_types): return [device_util.resolve(destinations)] else: @@ -225,7 +246,10 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): super(ReductionToOneDeviceCrossTowerOps, self).__init__() def _reduce(self, aggregation, per_device_value, destinations): - devices = get_devices_from(destinations or per_device_value) + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, self.accumulation_fn, aggregation) @@ -243,9 +267,9 @@ def _group_value_by_device(per_device_values): This grouping is needed to call the all-reduce library because it expects a list of the following form: - [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... - (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... - (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], + [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], + [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], ... ] @@ -266,7 +290,10 @@ def _group_value_by_device(per_device_values): return grouped -def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): +def _ungroup_and_make_mirrored(grouped_reduced, + destinations, + aggregation, + num_between_graph_workers=1): """Ungroup results from all-reduce and make Mirrored objects. Each all-reduce result will be divided by the number of destinations before @@ -279,6 +306,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): destinations: a list of device strings for returned Mirrored objects. aggregation: Indicates how a variable will be aggregated. Accepted values are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. + num_between_graph_workers: number of workers in the between-graph + replication. Returns: a list of Mirrored objects. @@ -287,7 +316,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): for d, per_device_reduced in enumerate(grouped_reduced): for i, (v, _) in enumerate(per_device_reduced): if aggregation == vs.VariableAggregation.MEAN: - index[i][destinations[d]] = v / len(destinations) + index[i][destinations[d]] = v / ( + len(destinations) * num_between_graph_workers) else: index[i][destinations[d]] = v return [value_lib.Mirrored(v) for v in index] @@ -508,7 +538,10 @@ class AllReduceCrossTowerOps(CrossTowerOps): logging.WARN, "Efficient allreduce is not supported for IndexedSlices.", 10) - devices = get_devices_from(destinations or per_device_value) + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, math_ops.add_n, aggregation) @@ -534,12 +567,12 @@ class AllReduceCrossTowerOps(CrossTowerOps): def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( - "batch_all_reduce invoked for batches size = %d with " + logging.log_first_n( + logging.INFO, "batch_all_reduce invoked for batches size = %d with " "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_alg, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) @@ -644,12 +677,13 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( + logging.log_first_n( + logging.INFO, "distributed batch_all_reduce invoked for batches size = %d with " "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " - "and agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "and agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_spec, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = sorted(per_device_values[0].devices) device_grads = _group_value_by_device(per_device_values) @@ -692,6 +726,102 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): aggregation) +# TODO(yuefengz): support in-graph collective all-reduce. +class CollectiveAllReduce(CrossTowerOps): + """All-reduce cross tower ops using collective ops. + + In the between-graph replicated training, it will still do all-reduces across + all workers and then put results on the right destinations. + """ + + def __init__(self, + num_workers=1, + num_gpus_per_worker=0, + all_reduce_merge_scope=1, + collective_keys=None): + """Initializes the object. + + Args: + num_workers: number of workers in the between-graph replicated training. + num_gpus_per_worker: number of GPUs per worker. + all_reduce_merge_scope: size of groups into which to partition consecutive + gradients grouped under a common 'allreduce' name scope. This is useful + for some optimization of collective ops. + collective_keys: an optional CollectiveKey object. + """ + self._num_workers = num_workers + self._num_gpus_per_worker = num_gpus_per_worker + self._all_reduce_merge_scope = all_reduce_merge_scope + self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys( + ) + super(CollectiveAllReduce, self).__init__() + + # TODO(yuefengz, tucker): is index slices supported by collective ops? + def _reduce(self, aggregation, per_device_value, destinations): + all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] + if destinations is None or _devices_match(per_device_value, destinations): + return all_reduced + else: + index = {} + for d in get_devices_from(destinations): + # pylint: disable=protected-access + if d in all_reduced._index: + index[d] = all_reduced._index[d] + else: + with ops.device(d): + index[d] = array_ops.identity(list(all_reduced._index.values())[0]) + return value_lib.Mirrored(index) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, aggregation, per_device_values): + """All-reduce across all workers in a batch.""" + if context.executing_eagerly(): + raise ValueError("Eager mode with collective ops is not supported yet.") + + logging.log_first_n( + logging.INFO, "Collective All-reduce invoked with batches size = %d, " + "num_workers = %d" % (len(per_device_values), self._num_workers), 10) + + grouped_by_tower = _group_value_by_device(per_device_values) + + grouped_by_var = list(zip(*grouped_by_tower)) + # grouped_by_var is grouped by variables and takes the following format: + # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), + # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), + # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), + # ... + # ] + chunked_gv = [ + grouped_by_var[x:x + self._all_reduce_merge_scope] + for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope) + ] + + reduced_gv_list = [] + for chunk in chunked_gv: + with ops.name_scope("allreduce"): + for grad_and_vars in chunk: + scaled_grads = [g for g, _ in grad_and_vars] + collective_reduced = cross_tower_utils.build_collective_reduce( + scaled_grads, self._num_workers, self._collective_keys, "Add", + "Id") + result = [] + for (_, v), g in zip(grad_and_vars, collective_reduced): + result.append([g, v]) + reduced_gv_list.append(result) + + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return _ungroup_and_make_mirrored( + new_tower_grads, + per_device_values[0].devices, + aggregation, + num_between_graph_workers=self._num_workers) + + _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 6a780ff60ffcd59d416278bfde6d005d7ad37a68..aec53b01d7a089fec08eec6ea43373a2cd8267d6 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -21,13 +21,17 @@ from __future__ import print_function import itertools from absl.testing import parameterized +import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -376,5 +380,166 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_tower_ops, distribution) +class MultiWorkerCollectiveAllReduceTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 100000 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + "fake_worker_0", "fake_worker_1", "fake_worker_2" + ] + } + + def setUp(self): + super(MultiWorkerCollectiveAllReduceTest, self).setUp() + # Reusing keys are not supported well. So we have to give a different + # collective key base for different tests. + MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 + + def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_start=num_gpus * 100 + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + MultiWorkerCollectiveAllReduceTest.collective_key_base) + if local_mode: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = ["/device:GPU:%d" % i for i in range(num_gpus)] + else: + devices = ["/device:CPU:0"] + return collective_all_reduce_ops, devices, "local" + else: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 3, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = [ + "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) + for i in range(num_gpus) + ] + else: + devices = ["/job:%s/task:%d" % (task_type, task_id)] + return collective_all_reduce_ops, devices, self._workers[task_id].target + + def _assert_values_equal(self, left, right, sess): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r, sess) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(set(left.devices), set(right.devices)) + + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 6 + + left_values = np.array( + sess.run(list(left._index.values()), options=run_options)).flatten() + right_values = np.array(list(right._index.values())).flatten() + self.assertEqual(len(left_values), len(right_values)) + for l, r in zip(left_values, right_values): + self.assertEqual(l, r) + + def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + collective_all_reduce, devices, master_target = self._get_test_objects( + task_type, task_id, num_gpus, local_mode=local_mode) + if local_mode: + num_workers = 1 + worker_device = None + else: + num_workers = len(self._workers) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + with ops.Graph().as_default(), \ + ops.device(worker_device), \ + self.test_session(target=master_target) as sess: + # Collective ops doesn't support scalar tensors, so we have to construct + # 1-d tensors. + values = [constant_op.constant([float(d)]) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = np.array([(len(devices) - 1.) / 2.]) + + values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = np.array([mean[0] + 1.]) + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), + _fake_mirrored(mean, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device, + destinations=destinations), + _fake_mirrored(mean * len(devices) * num_workers, destinations or + per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or + per_device), sess) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ], sess) + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices) * num_workers, d1 or + per_device), + _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or + per_device_2) + ], sess) + + return True + + @combinations.generate( + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2])) + def testReductionDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients(self._test_reduction, self._cluster_spec, + num_gpus) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 2bb088e704c584598b863b1b836166af2a5bb12c..24cb08fb48f832572da5ae2113e6c224557c6a81 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function import collections as pycoll +import threading from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -218,6 +221,146 @@ def split_grads_by_size(threshold_size, device_grads): return small_grads, large_grads +# threading.Lock() cannot be pickled and therefore cannot be a field of +# CollectiveKeys. +_lock = threading.Lock() + + +# TODO(yuefengz): use random key starts to avoid reusing keys? +class CollectiveKeys(object): + """Class that manages collective keys. + + We need to manage three different keys for collective: + + *Group key*: an integer key to identify the set of cooperative devices. + Collective ops work under the same set of devices must using the same group + key. + + *Instance key*: an integer key to identify the set of same counterpart of + tensors on different devices in a device group that need to be all-reduced. + + "Graph key": an integer key that is unique key graph. This is used to support + multiple graphs per client session. It must be non-zero and set in the + `config` argument of each call to `session.run`. + """ + + def __init__(self, + group_key_start=1, + instance_key_start=100, + instance_key_with_id_start=10000): + """Initializes the object. + + Args: + group_key_start: the starting integer of group key. + instance_key_start: the starting integer of instance key. + instance_key_with_id_start: the starting integer of instance key that is + recorded with an id. + """ + self._group_key = group_key_start + self._group_key_table = dict() + + # For instance keys with ids + self._instance_key_id_to_key_table = dict() + self._instance_key_with_id_counter = instance_key_with_id_start + + # For instance keys without ids + self._instance_key_start = instance_key_start + + self._thread_local = threading.local() + + def _get_thread_local_object(self): + # We make instance key without key ids thread local so that it will work + # with MirroredStrategy and distribute coordinator. + if not hasattr(self._thread_local, 'instance_key'): + self._thread_local.instance_key = self._instance_key_start + return self._thread_local + + def get_group_key(self, devices): + """Returns a group key for the set of devices. + + Args: + devices: list of strings naming devices in a collective group. + + Returns: + int key uniquely identifying the set of device names. + """ + parsed = [pydev.DeviceSpec.from_string(d) for d in devices] + # In the between-graph replicated training, different workers need to get + # the same device key. So we remove the task_type and task_id from the + # devices. + # TODO(yuefengz): in the in-graph replicated training, we need to include + # task_type and task_id. + names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) + key_id = ','.join(names) + with _lock: + if key_id not in self._group_key_table: + new_key = self._group_key + self._group_key += 1 + self._group_key_table[key_id] = new_key + return self._group_key_table[key_id] + + def get_instance_key(self, key_id=None): + """Returns a new instance key for use in defining a collective op. + + Args: + key_id: optional string. If set, key will be recorded and the same key + will be returned when the same key_id is provided. If not, an increasing + instance key will be returned. + """ + if key_id: + with _lock: + if key_id not in self._instance_key_id_to_key_table: + self._instance_key_with_id_counter += 1 + self._instance_key_id_to_key_table[key_id] = ( + self._instance_key_with_id_counter) + return self._instance_key_id_to_key_table[key_id] + else: + v = self._get_thread_local_object().instance_key + self._get_thread_local_object().instance_key += 1 + return v + + +def build_collective_reduce(input_tensors, + num_workers, + collective_keys, + reduction_op='Add', + unary_op='Id'): + """Build a subgraph that does one full all-reduce, using the collective Op. + + Args: + input_tensors: tensors within a single worker graph that are to be reduced + together; must be one per device. + num_workers: total number of workers with identical independent graphs that + will be doing this same reduction. The reduction will actually include + the corresponding tensors at all these workers. + collective_keys: a CollectiveKeys object. + reduction_op: string naming the reduction op. + unary_op: string naming the unary final op. + + Returns: + An array of final tensors, one per device, computed by the full reduction. + + Raises: + ValueError: There must be at least two tensors over all the workers. + """ + group_size = len(input_tensors) * num_workers + if group_size < 2: + raise ValueError('num_workers * len(input_tensors) must be 2 or greater') + devices = [t.device for t in input_tensors] + num_devices = len(devices) + group_key = collective_keys.get_group_key(devices) + instance_key = collective_keys.get_instance_key() + out_tensors = [] + subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec + for d in range(num_devices): + with ops.device(devices[d]): + reduce_op = collective_ops.all_reduce( + input_tensors[d], group_size, group_key, instance_key, reduction_op, + unary_op, subdiv_offsets) + out_tensors.append(reduce_op) + return out_tensors + + def sum_grad_and_var_all_reduce(grad_and_vars, num_workers, alg, @@ -253,10 +396,10 @@ def sum_grad_and_var_all_reduce(grad_and_vars, else: raise ValueError('unsupported all_reduce alg: ', alg) - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 34410a6470185ac2821bc6a59de9230ff478aeb6..a0bb144b7c7e73b051e41ab93086cbb5b6a852cc 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -96,7 +96,8 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, # TODO(isaprykin): Work around the colocate_with error. dnn_optimizer=adagrad.AdagradOptimizer(0.001), linear_optimizer=adagrad.AdagradOptimizer(0.001), - config=run_config.RunConfig(train_distribute=distribution)) + config=run_config.RunConfig( + train_distribute=distribution, eval_distribute=distribution)) num_steps = 10 estimator.train(train_input_fn, steps=num_steps) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 00c25c7a2482a559c8b94ff3be86c4961dfb439f..44a69ed23a4e00ab81d5b51ae0c14550bd493f14 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -59,7 +59,8 @@ def build_model_fn_optimizer(): def main(_): distribution = tf.contrib.distribute.MirroredStrategy( ["/device:GPU:0", "/device:GPU:1"]) - config = tf.estimator.RunConfig(train_distribute=distribution) + config = tf.estimator.RunConfig(train_distribute=distribution, + eval_distribute=distribution) def input_fn(): features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) @@ -70,7 +71,7 @@ def main(_): model_fn=build_model_fn_optimizer(), config=config) estimator.train(input_fn=input_fn, steps=10) - eval_result = estimator.evaluate(input_fn=input_fn) + eval_result = estimator.evaluate(input_fn=input_fn, steps=10) print("Eval result: {}".format(eval_result)) def predict_input_fn(): diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index 2b05884b9b93470ef9a764cbedbc91bd3912c611..518ec9c4232465c3ecd0e4161f707dac499430c7 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -57,7 +57,8 @@ def main(args): # tf.Estimator that utilizes the DistributionStrategy. strategy = tf.contrib.distribute.MirroredStrategy( ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) + config = tf.estimator.RunConfig( + train_distribute=strategy, eval_distribute=strategy) keras_estimator = tf.keras.estimator.model_to_estimator( keras_model=model, config=config, model_dir=model_dir) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 75ecd90dcffa7a786b78238ef453c4c8e4346afa..ec0ca6879cffb9214adec15058cfb7293d347b25 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -12,33 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Keras Sequential and Functional models.""" +"""Tests for tf.keras models using DistributionStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os - import numpy as np from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop + _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as +# part of the tf.keras unit tests suite. def simple_sequential_model(): model = keras.models.Sequential() model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) @@ -84,7 +91,7 @@ def get_ds_test_input_fn(): return dataset -class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -107,7 +114,8 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=dist, + eval_distribute=dist) with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -144,5 +152,416 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + def test_keras_optimizer_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.rmsprop(lr=0.01)) + + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator(keras_model=keras_model, + config=config) + with self.assertRaisesRegexp(ValueError, + 'Only TensorFlow native optimizers are ' + 'supported with DistributionStrategy.'): + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + +class TestWithDistributionStrategy(test.TestCase): + + def test_validating_dataset_input_tensors_with_shape_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_calling_model_on_same_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Call fit with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.predict(dataset, steps=2) + + def test_fit_eval_and_predict_methods_on_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + # Test with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + + def test_raise_error_for_stateful_metrics(self): + + class ExampleStatefulMetric(keras.layers.Layer): + + def __init__(self, name='true_positives', **kwargs): + super(ExampleStatefulMetric, self).__init__(name=name, **kwargs) + self.stateful = True + + def __call__(self, y_true, y_pred): + return y_pred - y_true + + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', ExampleStatefulMetric()] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + with self.assertRaisesRegexp( + NotImplementedError, 'Stateful metrics are not supported with ' + 'DistributionStrategy.'): + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + def test_unsupported_features(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit(dataset, + epochs=1, steps_per_epoch=2, verbose=0, + validation_split=0.5, validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + NotImplementedError, 'sample_weight is currently not supported when ' + 'using DistributionStrategy.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument. + with self.assertRaisesRegexp( + ValueError, 'you should specify the `steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.predict(dataset, verbose=0) + + def test_calling_with_unsupported_predefined_callbacks(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + def schedule(_): + return 0.001 + with self.assertRaisesRegexp(ValueError, + 'LearningRateScheduler callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp(ValueError, + 'ReduceLROnPlateau callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + with self.assertRaisesRegexp(ValueError, + 'histogram_freq in the TensorBoard callback ' + 'is not supported when using ' + 'DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) + + def test_dataset_input_shape_validation(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, distribute=strategy) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have 2 dimensions'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + def test_learning_phase_value(self): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. + with self.test_session(): + x = keras.layers.Input(shape=(16,), name='input') + y = keras.layers.Dense(16)(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.random.rand(10, 16) + targets = np.ones((10, 16), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(8) + + hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1) + self.assertEqual(hist.history['acc'][0], 1) + + evaluate_output = model.evaluate(dataset, steps=20) + self.assertEqual(evaluate_output[1], 0) + + predict_output = model.predict(dataset, steps=1) + self.assertNotEqual(np.mean(predict_output), 0) + + +class LossMaskingWithDistributionStrategyTest(test.TestCase): + + def test_masking(self): + with self.test_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class NormalizationLayerWithDistributionStrategyTest(test.TestCase): + + def test_batchnorm_correctness(self): + with self.test_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0', + '/device:GPU:0']) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = dataset.batch(32) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class CorrectnessWithDistributionStrategyTest(test.TestCase): + + def test_correctness(self): + with self.test_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + + # With DistributionStrategy + dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset_with = dataset_with.batch(32) + strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', + '/device:GPU:0'], + prefetch_on_device=False) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + distribute=strategy) + model.fit(x=dataset_with, epochs=1, steps_per_epoch=310) + wts_with_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict, + x_predict)) + predict_dataset_with = predict_dataset_with.batch(2) + predict_with_ds = model.predict(predict_dataset_with, steps=1) + predict_with_ds = np.reshape(predict_with_ds, (4, 1)) + + # Without DistributionStrategy + dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train, + y_train)) + dataset_without = dataset_without.batch(64) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5)) + model.fit(x=dataset_without, epochs=1, steps_per_epoch=310) + wts_without_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_without = dataset_ops.Dataset.from_tensor_slices(( + x_predict, x_predict)) + predict_dataset_without = predict_dataset_without.batch(4) + predict_without_ds = model.predict(predict_dataset_without, steps=1) + + # Verify that the weights are the same within some limits of tolerance. + np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) + # Verify that the predicted outputs are the same within some limits of + # tolerance. + np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 6c6bf143098c1bba64d47efce1bfface7682683d..2f3d6bdd3f4e4bc7352d7b378ed40b930608ef08 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -19,7 +19,6 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import combinations from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test @@ -183,7 +182,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _dataset_fn(): dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) # Want to produce a fixed, known shape, so drop remainder when batching. - dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + dataset = dataset.batch(4, drop_remainder=True) return dataset def _expected_fn(num_batches): diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index dcbc6b0878b89cbb5b9779de315429e6f9478d15..3c1760c03c6001f421d208ad1bf9cde9a1033081 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -20,7 +20,6 @@ from __future__ import print_function import contextlib import threading -import six from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import shared_variable_creator @@ -28,13 +27,17 @@ from tensorflow.contrib.distribute.python import values from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape +from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import coordinator from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -60,6 +63,233 @@ class _RequestedStop(Exception): pass +# Make _call_for_each_tower and _reduce_non_distributed_value not members of +# MirroredStrategy so that they are generally not allowed to use anything +# specific to MirroredStrategy and thus can be shared with other distribution +# strategies. + + +# TODO(yuefengz): maybe create a common class for those who need to call this +# _call_for_each_tower. +def _call_for_each_tower(distribution, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + distribution: the DistributionStrategy object. + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times from the available devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(distribution.worker_devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access + distribution, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup({t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn(distribution, *merge_args, + **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + +def _reduce_non_distributed_value(distribution, aggregation, value, + destinations): + """Reduce a non-DistributedValue `value` to `destinations`.""" + if isinstance(value, values.DistributedValues): + raise ValueError("You are passing a `DistributedValue` to " + "`_reduce_non_distributed_value`, which is not allowed.") + + # If the same value is present on all towers then the PerDevice value will + # be a single value. We also handle the case when `value` is a single value + # and equal to 0. + if value == 0: + return 0 + # If the aggregation type is MEAN, then this essentially means that the same + # value should be on all destinations. + if aggregation == variable_scope.VariableAggregation.MEAN: + return distribution.broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + # We do not support an aggregation type of SUM if the value is the same across + # all towers. We call this as part of assign functions for MirroredVariables + # and summing up identical values across towers is not clearly defined. + if (len(distribution.worker_devices) != 1 or + not cross_tower_ops_lib.check_destinations(destinations)): + raise ValueError("A non-DistributedValues value cannot be reduced with the " + "given aggregation.") + # TODO(anjalisridhar): Moves these methods to a device utility file? + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + with ops.device(devices[0]): + return array_ops.identity(value) + else: + value_updates = {} + for d in devices: + with ops.device(d): + value_updates[d] = array_ops.identity(value) + return values.Mirrored(value_updates) + + +def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Get synchronization value + synchronization = kwargs.get("synchronization", + variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True + kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [ + variable_scope.VariableAggregation.NONE, + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) + + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = real_mirrored_creator(devices, *args, **kwargs) + + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], aggregation) + else: + result = values.MirroredVariable(index, index[devices[0]], aggregation) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices on a single machine. @@ -94,54 +324,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - # Get synchronization value - synchronization = kwargs.get( - "synchronization", variable_scope.VariableSynchronization.ON_WRITE) - if synchronization == variable_scope.VariableSynchronization.NONE: - raise ValueError("`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please" - " change the `synchronization` for variable: " + - kwargs["name"]) - elif synchronization == variable_scope.VariableSynchronization.ON_READ: - # Variables that are to be synced on read are tower local. - is_tower_local = True - kwargs["trainable"] = False - elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or - synchronization == variable_scope.VariableSynchronization.AUTO): - # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. - is_tower_local = False - else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) - - # Get aggregation value - aggregation = kwargs.pop("aggregation", - variable_scope.VariableAggregation.NONE) - if aggregation not in [ - variable_scope.VariableAggregation.NONE, - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN - ]: - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): + def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring index = {} for i, d in enumerate(devices): with ops.device(d): @@ -165,149 +351,71 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v + return index - if is_tower_local: - result = values.TowerLocalVariable(index, index[devices[0]], - aggregation) - else: - result = values.MirroredVariable(index, index[devices[0]], aggregation) - - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - l.remove(v) - g.add_to_collections(collections, result) - return result + return _create_mirrored_variable(devices, _real_mirrored_creator, *args, + **kwargs) def distribute_dataset(self, dataset_fn): return values.PerDeviceDataset( self._call_dataset_fn(dataset_fn), self._devices, self._prefetch_on_device) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_result = fn(ctx, iterator.get_next()) + for (name, output) in ctx.last_step_outputs.items(): + # Convert all outputs to tensors, potentially from `DistributedValues`. + ctx.last_step_outputs[name] = self.unwrap(output) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, wrap them in a Mirrored + # container, else in a PerDevice container. + if aggregation is variables_lib.VariableAggregation.NONE: + last_step_tensor_outputs_dict[name] = values.regroup( + {d: t for d, t in zip(self._devices, output)}, values.PerDevice) + else: + assert len(output) == 1 + last_step_tensor_outputs_dict[name] = output[0] + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx + def _broadcast(self, tensor, destinations): # TODO(josh11b): In eager mode, use one thread per device, or async mode. return self._get_cross_tower_ops().broadcast(tensor, destinations or self._devices) def _call_for_each_tower(self, fn, *args, **kwargs): - """Run `fn` in separate threads, once per tower/worker device. - - Args: - fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. - - Returns: - Merged return value of `fn` across all towers. - - Raises: - RuntimeError: If fn() calls get_tower_context().merge_call() a different - number of times for when called for different devices. - """ - run_concurrently = kwargs.pop("run_concurrently", True) - if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True - - coord = coordinator.Coordinator( - clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. - threads = [] - for index, d in enumerate(self._devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredTowerThread( - self, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredTowerThread - # (`MTT`) threads. The execution waits until - # `MTT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_tower_context().merge_call()` is called. If `fn` is - # complete, then `MTT.done` is set to True. Otherwise, arguments - # of `get_tower_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_tower_context().merge_call` are then set to `MTT.merge_result`. - # Each such `get_tower_context().merge_call` call returns the - # `MTT.merge_result` for that thread when `MTT.should_run` event - # is reset again. Execution of `fn` resumes. - - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some towers made a different number of " - "tower_context().merge_call() calls.") - # get_tower_context().merge_call() case - merge_args = values.regroup( - {t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MTT when we call merge_fn - # to ensure that if we have opened a name scope in the MTT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MTT and assume it is - # the same for all other MTTs. - mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) + return _call_for_each_tower(self, fn, *args, **kwargs) def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. @@ -337,29 +445,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) - if not isinstance(value, values.PerDevice): - if value == 0: - return 0 - if aggregation == variable_scope.VariableAggregation.MEAN: - return self._broadcast(value, destinations) - - cross_tower_ops_lib.validate_destinations(destinations) - if len(self._devices) == 1: - if destinations: - # TODO(anjalisridhar): Moves these methods to a device utility file? - devices = cross_tower_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - with ops.device(devices[0]): - return array_ops.identity(value) - else: - value_updates = {} - for d in devices: - with ops.device(d): - value_updates[d] = array_ops.identity(value) - return values.Mirrored(value_updates) - raise ValueError("A non PerDevice value cannot be reduced with the given " - "aggregation.") - + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerDevice or Mirrored + # values. For example, the same value could be present on all towers in + # which case `value` would be a single value or value could be 0. + return _reduce_non_distributed_value(self, aggregation, value, + destinations) return self._get_cross_tower_ops().reduce( aggregation, value, destinations=destinations) @@ -406,6 +497,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] return [val] + def value_container(self, val): + return values.value_container(val) + @property def is_single_tower(self): return len(self._devices) == 1 @@ -433,15 +527,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _get_devices_from(self, colocate_with=None): if colocate_with is None: return self._devices - elif isinstance(colocate_with, values.DistributedValues): - # pylint: disable=protected-access - return list(colocate_with._index.keys()) - elif isinstance(colocate_with, six.string_types): - return [device_util.resolve(colocate_with)] - elif isinstance(colocate_with, list): - return [device_util.resolve(d) for d in colocate_with] else: - return colocate_with + return cross_tower_ops_lib.get_devices_from(colocate_with) class _MirroredTowerThread(threading.Thread): """A thread that runs() a function on a device.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9807ce43515a9f1000f62c279f9dcf16491e4fba..e064cfe37db40a51e18a16c532500415a8b74816 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -25,7 +25,9 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,6 +39,7 @@ from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib @@ -792,8 +795,8 @@ class MirroredVariableUpdateTest(test.TestCase): return mirrored_var.assign(5.0) with self.assertRaisesRegexp( - ValueError, "A non PerDevice value cannot be reduced with the given " - "aggregation."): + ValueError, "A non-DistributedValues value cannot be reduced with " + "the given aggregation."): self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) @test_util.run_in_graph_and_eager_modes(config=config) @@ -838,6 +841,29 @@ class MirroredVariableUpdateTest(test.TestCase): model_fn, run_concurrently=False))) self.assertEquals(0.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) def testAssignAddMirroredVarCrossTowerContext(self): self._skip_eager_if_gpus_less_than(1) @@ -880,6 +906,29 @@ class MirroredVariableUpdateTest(test.TestCase): model_fn, run_concurrently=False))) self.assertEquals(1.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_add(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(6.0, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) def testAssignSubMirroredVarCrossTowerContext(self): self._skip_eager_if_gpus_less_than(1) @@ -922,6 +971,29 @@ class MirroredVariableUpdateTest(test.TestCase): model_fn, run_concurrently=False))) self.assertEquals(4.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_sub(1.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(4.0, self.evaluate(mirrored_var)) + class MirroredAndTowerLocalVariableInitializerTest(test.TestCase): config = config_pb2.ConfigProto() @@ -974,7 +1046,7 @@ class TowerLocalVariableAssignTest(test.TestCase): def _skip_eager_if_gpus_less_than(self, num_gpus): if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") + self.skipTest("Not enough GPUs available for this test in eager mode.") @test_util.run_in_graph_and_eager_modes(config=config) def testAssignTowerLocalVarSumAggregation(self): @@ -1036,5 +1108,131 @@ class TowerLocalVariableAssignTest(test.TestCase): self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) +class MockModel(object): + + def __init__(self, two_variables=False): + self.variables = [] + self.variables.append(variable_scope.variable(1.25, name="dummy_var1")) + if two_variables: + self.variables.append(variable_scope.variable(2.0, name="dummy_var2")) + + def __call__(self, factor=2): + x = factor * self.variables[0] + if len(self.variables) > 1: + x += self.variables[1] + return x + + +class MirroredStrategyDefunTest(test.TestCase): + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Not enough GPUs available for this test in eager mode.") + + def _call_and_check(self, model_fn, inputs, expected_result, defuns, + two_variables=False): + cpu_dev = device_util.canonicalize("CPU:0") + gpu_dev = device_util.canonicalize("GPU:0") + devices = [cpu_dev, gpu_dev] + dist = mirrored_strategy.MirroredStrategy(devices) + + with dist.scope(): + mock_model = MockModel(two_variables) + self.evaluate(variables.global_variables_initializer()) + + result = dist.call_for_each_tower(model_fn, mock_model, *inputs, + run_concurrently=False) + for device in devices: + device_result = values.select_device(device, result) + device_expected_result = values.select_device(device, expected_result) + self.assertAllClose(device_expected_result, + self.evaluate(device_result)) + + for defun in defuns: + self.assertEqual(set(mock_model.variables), set(defun.variables)) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + def model_fn(mock_model): + return times_two(mock_model) + + self._call_and_check(model_fn, [], 2.5, [times_two]) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + @function.defun + def two_x_plus_one(mock_model): + return times_two(mock_model) + 1 + + def model_fn(mock_model): + return two_x_plus_one(mock_model) + + self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) + + @test_util.run_in_graph_and_eager_modes() + def testTwoVariablesInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + return fn2(mock_model) + + self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testGradientTapeOverNestedDefuns(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + with backprop.GradientTape(persistent=True) as gtape: + result = fn2(mock_model) + grads = gtape.gradient(result, + [v.get() for v in mock_model.variables]) + return grads + + self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testPassPerDevice(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model, factor): + return mock_model(factor) + + factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) + self._call_and_check(fn1, [factors], expected_result, [fn1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index f659be5f42594b275af06435cb0c228e5d594ac9..249de01f0880b02d603687db99692088480f7136 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -20,35 +20,68 @@ from __future__ import print_function import contextlib import copy +import threading +import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session -from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.platform import test from tensorflow.python.framework import test_util +def create_in_process_cluster(num_workers, num_ps): + """Create an in-process cluster that consists of only standard server.""" + # Leave some memory for cuda runtime. + gpu_mem_frac = 0.7 / num_workers + worker_config = config_pb2.ConfigProto() + worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # Enable collective ops which has no impact on non-collective ops. + # TODO(yuefengz, tucker): removing this after we move the initialization of + # collective mgr to the session level. + worker_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + ps_config = config_pb2.ConfigProto() + ps_config.device_count['GPU'] = 0 + + # Create in-process servers. Once an in-process tensorflow server is created, + # there is no way to terminate it. So we create one cluster per test process. + # We could've started the server in another process, we could then kill that + # process to terminate the server. The reasons why we don't want multiple + # processes are + # 1) it is more difficult to manage these processes; + # 2) there is something global in CUDA such that if we initialize CUDA in the + # parent process, the child process cannot initialize it again and thus cannot + # use GPUs (https://stackoverflow.com/questions/22950047). + return test_util.create_local_cluster( + num_workers, + num_ps=num_ps, + worker_config=worker_config, + ps_config=ps_config, + protocol='grpc') + + class MultiWorkerTestBase(test.TestCase): """Base class for testing multi node strategy and dataset.""" @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - num_workers = 2 - # Leave some memory for cuda runtime. - gpu_mem_frac = 0.7 / num_workers - default_config = config_pb2.ConfigProto() - default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac - - # The local cluster takes some portion of the local GPUs and there is no way - # for the cluster to terminate unless using multiple processes. Therefore, - # we have to only create only one cluster throughout a test process. - workers, _ = test_util.create_local_cluster( - num_workers, num_ps=0, worker_config=default_config) - cls._master_target = workers[0].target + cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0) + + def setUp(self): + # We only cache the session in one test because another test may have a + # different session config or master target. + self._thread_local = threading.local() + self._thread_local.cached_session = None + self._result = 0 + self._lock = threading.Lock() @contextlib.contextmanager - def test_session(self, graph=None, config=None): + def test_session(self, graph=None, config=None, target=None): """Create a test session with master target set to the testing cluster. This overrides the base class' method, removes arguments that are not needed @@ -59,6 +92,7 @@ class MultiWorkerTestBase(test.TestCase): graph: Optional graph to use during the returned session. config: An optional config_pb2.ConfigProto to use to configure the session. + target: the target of session to connect to. Yields: A Session object that should be used as a context manager to surround @@ -78,13 +112,46 @@ class MultiWorkerTestBase(test.TestCase): rewriter_config_pb2.RewriterConfig.OFF) if graph is None: - if self._cached_session is None: # pylint: disable=access-member-before-definition - self._cached_session = session.Session( - graph=None, config=config, target=self._master_target) - sess = self._cached_session + if getattr(self._thread_local, 'cached_session', None) is None: + self._thread_local.cached_session = session.Session( + graph=None, config=config, target=target or self._workers[0].target) + sess = self._thread_local.cached_session with sess.graph.as_default(), sess.as_default(): yield sess else: with session.Session( - graph=graph, config=config, target=self._master_target) as sess: + graph=graph, config=config, target=target or + self._workers[0].target) as sess: yield sess + + def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, + **kwargs): + result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) + if np.all(result): + with self._lock: + self._result += 1 + + def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args, + **kwargs): + """Runs several clients for between-graph replication. + + Args: + client_fn: a function that needs to accept `task_type`, `task_id`, + `num_gpus` and returns True if it succeeds. + cluster_spec: a dict specifying jobs in a cluster. + num_gpus: number of GPUs per worker. + *args: will be passed to `client_fn`. + **kwargs: will be passed to `client_fn`. + """ + threads = [] + for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]: + for task_id in range(len(cluster_spec.get(task_type, []))): + t = threading.Thread( + target=self._run_client, + args=(client_fn, task_type, task_id, num_gpus) + args, + kwargs=kwargs) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertEqual(self._result, len(threads)) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index dbd3514aec7d40d9a04dba4bcbc5c14be639aa33..016978cdb3a152bbba0a2e63df1dea4035e32789 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -21,11 +21,14 @@ from __future__ import print_function import six from tensorflow.contrib.distribute.python import values +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -66,6 +69,41 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def _broadcast(self, tensor, destinations): return tensor + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_result = fn(ctx, iterator.get_next()) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + # TODO(priyag): Use max_iterations instead of an explicit counter. + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx + def _call_for_each_tower(self, fn, *args, **kwargs): # We don't run `fn` in multiple threads in OneDeviceStrategy. kwargs.pop("run_concurrently", None) @@ -105,6 +143,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def _unwrap(self, value): return [value] + def value_container(self, value): + return value + @property def is_single_tower(self): return True diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c7fd556adaf08926b6f1e327abd25b7c9a42e6 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -0,0 +1,358 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a multi-worker ps DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import device_setter +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest + +_LOCAL_CPU = "/device:CPU:0" +_LOCAL_GPU_0 = "/device:GPU:0" + + +def _normalize_cluster_spec(cluster_spec): + """Makes `cluster_spec` into a `ClusterSpec` object.""" + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): maybe cache variables on local CPU. +# TODO(yuefengz): we may want to set session options to disallow communication +# between workers. +class ParameterServerStrategy(distribute_lib.DistributionStrategy): + """A parameter server DistributionStrategy. + + This strategy class works for both local training and between-graph replicated + training for multiple workers. If `cluster_spec` is specified, either passed + in to __init__() method or parsed from the + ["TF_CONFIG" environment + variable](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig), + variables and updates to those variables are assigned to parameter servers and + other operations are assigned to workers. If `cluster_spec` is not set, it + becomes local training where variables are assigned to local CPU or the only + GPU. When each worker has more than one GPU, operations will be replicated on + these GPUs. In both cases, operations are replicated but variables are not and + these workers share a common view for which paramater server a variable is + assigned to. + + This class assumes between-graph replication will be used and works on a graph + for a particular worker. + + It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any + operations which potentially can be replicated across towers (i.e. multiple + GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra + caution needs to be taken: + + 1) Always use @{tf.get_variable} instead of @{tf.Variable} which is not able + to refer to the same variable on different towers. + + 2) It is generally not recommended to open a device scope under the strategy's + scope. A device scope (i.e. calling @{tf.device}) will be merged with or + override the device for operations but will not change the device for + variables. + + 3) It is also not recommended to open a colocation scope (i.e. calling + @{tf.colocate_with}) under the strategy's scope. For colocating variables, + use `distribution.colocate_vars_with` instead. Colocation of ops will possibly + create conflicts of device assignement. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type=None, + task_id=None): + """Initiailizes this strategy. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + """ + super(ParameterServerStrategy, self).__init__() + self._num_gpus_per_worker = num_gpus_per_worker + if cluster_spec: + cluster_spec = _normalize_cluster_spec(cluster_spec) + self._cluster_spec = cluster_spec + + # We typically don't need to do all-reduce in this strategy. + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_LOCAL_CPU)) + + self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type, + task_id) + + def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type, + task_id): + """Initialize internal devices. + + It creates variable devices and compute devices. Variables and operations + will be assigned to them respectively. We have one compute device per tower. + The variable device is a device function or device string. The default + variable device assigns variables to parameter servers in a round-robin + fashion. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + + Raises: + ValueError: if the cluster_spec doesn't have ps jobs. + """ + self._task_type = task_type or "worker" + self._task_id = task_id or 0 + self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) + + # TODO(yuefengz): maybe clearer to split it into two classes, one for + # the distribuetd case and one for the local case, once we have the factory + # class/method. + + # Define compute devices which is a list of device strings and one for each + # tower. When there are GPUs, replicate operations on these GPUs. Otherwise, + # place operations on CPU. + if cluster_spec is None: + # Local mode. + if num_gpus_per_worker > 0: + self._compute_devices = list( + map("/device:GPU:{}".format, range(num_gpus_per_worker))) + else: + self._compute_devices = [_LOCAL_CPU] + else: + # Distributed mode. + if num_gpus_per_worker > 0: + self._compute_devices = [ + "%s/device:GPU:%d" % (self._worker_device, i) + for i in range(num_gpus_per_worker) + ] + else: + self._compute_devices = [self._worker_device] + + self._compute_devices = list( + map(device_util.resolve, self._compute_devices)) + self._canonical_compute_device_set = set(self._compute_devices) + + # Define variable device which is a device string in the local case and a + # device function in the distributed case. It is used to open a device scope + # where varibles are defined. + # The `_parameter_devices` is needed for the `parameter_devices` property + # and is a list of all variable devices. + if cluster_spec is None: + # Local mode. If there is only one GPU, put everything on that GPU. + # Otherwise, place variables on CPU. + if num_gpus_per_worker == 1: + assert len(list(self._compute_devices)) == 1 + self._variable_device = _LOCAL_GPU_0 + self._parameter_devices = [_LOCAL_GPU_0] + else: + self._variable_device = _LOCAL_CPU + self._parameter_devices = [_LOCAL_CPU] + else: + # Distributed mode. Place variables on ps jobs in a round-robin fashion. + # Note that devices returned from `replica_device_setter` are not + # canonical and therefore we don't canonicalize all variable devices to + # make them consistent. + # TODO(yuefengz): support passing a strategy object to control variable + # assignment. + # TODO(yuefengz): merge the logic of replica_device_setter into this + # class. + num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) + if num_ps_replicas == 0: + raise ValueError("The cluster spec needs to have `ps` jobs.") + self._variable_device = device_setter.replica_device_setter( + ps_tasks=num_ps_replicas, + worker_device=self._worker_device, + merge_devices=True, + cluster=cluster_spec) + + # Parameter devices are all tasks of the "ps" job. + self._parameter_devices = map("/job:ps/task:{}".format, + range(num_ps_replicas)) + + # Define the default device in cross-tower mode. In the distributed case, we + # set the default device to the corresponding worker to prevent these ops + # from being placed on other workers. + if cluster_spec is None: + self._default_device = None + else: + self._default_device = self._worker_device + + def distribute_dataset(self, dataset_fn): + """Distributes the dataset to each local GPU.""" + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._compute_devices, True) + + def _broadcast(self, tensor, destinations): + if not cross_tower_ops_lib.check_destinations(destinations): + destinations = self._compute_devices + return self._cross_tower_ops.broadcast(tensor, destinations) + + # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through + # this creator, such as "MutableHashTable". + def _create_variable(self, next_creator, *args, **kwargs): + if "colocate_with" in kwargs: + with ops.device(None): + with ops.colocate_with(kwargs["colocate_with"]): + return next_creator(*args, **kwargs) + + with ops.colocate_with(None, ignore_existing=True): + with ops.device(self._variable_device): + return next_creator(*args, **kwargs) + + def _call_for_each_tower(self, fn, *args, **kwargs): + # pylint: disable=protected-access + return mirrored_strategy._call_for_each_tower(self, fn, *args, **kwargs) + + def _verify_destinations_not_different_worker(self, destinations): + if destinations is None: + return + for d in cross_tower_ops_lib.get_devices_from(destinations): + d_spec = tf_device.DeviceSpec.from_string(d) + if d_spec.job == self._task_type and d_spec.task != self._task_id: + raise ValueError( + "Cannot reduce to another worker: %r, current worker is %r" % + (d, self._worker_device)) + + def _reduce(self, aggregation, value, destinations): + self._verify_destinations_not_different_worker(destinations) + if not isinstance(value, values.DistributedValues): + # pylint: disable=protected-access + return mirrored_strategy._reduce_non_distributed_value( + self, aggregation, value, destinations) + + return self._cross_tower_ops.reduce( + aggregation, value, destinations=destinations) + + def _batch_reduce(self, aggregation, value_destination_pairs): + for _, destinations in value_destination_pairs: + self._verify_destinations_not_different_worker(destinations) + return self._cross_tower_ops.batch_reduce(aggregation, + value_destination_pairs) + + def _select_single_value(self, structured): + """Select any single values in `structured`.""" + + def _select_fn(x): # pylint: disable=g-missing-docstring + if isinstance(x, values.Mirrored): + if len(x.devices) == 1: + return list(x._index.values())[0] # pylint: disable=protected-access + else: + raise ValueError( + "You cannot update variable with a Mirrored object with multiple " + "components %r when using ParameterServerStrategy. You must " + "specify a single value or a Mirrored with a single value." % x) + elif isinstance(x, values.PerDevice): + raise ValueError( + "You cannot update variable with a PerDevice object %r when using " + "ParameterServerStrategy. You must specify a single value or a " + "Mirrored with a single value" % x) + else: + return x + + return nest.map_structure(_select_fn, structured) + + def _update(self, var, fn, *args, **kwargs): + if not isinstance(var, resource_variable_ops.ResourceVariable): + raise ValueError( + "You can not update `var` %r. It must be a Variable." % var) + with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): + return fn(var, *self._select_single_value(args), + **self._select_single_value(kwargs)) + + # TODO(yuefengz): does it need to call _select_single_value? + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + with ops.device( + colocate_with.device), distribute_lib.UpdateContext(colocate_with): + return fn(*args, **kwargs) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_compute_device_set: + return [val.get(device=d) for d in self._compute_devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + def value_container(self, val): + return values.value_container(val) + + def read_var(self, var): + # No need to distinguish between normal variables and tower-local variables. + return array_ops.identity(var) + + def configure(self, session_config=None): + del session_config + + # Use TF_CONFIG to get the cluster spec and the current job. + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = None + + # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in + # the constructor. + if not self._cluster_spec and cluster_spec: + self._cluster_spec = cluster_spec + self._initialize_devices(self._num_gpus_per_worker, cluster_spec, + task_type, task_id) + + @property + def num_towers(self): + return len(self._compute_devices) + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._compute_devices) + + @property + def parameter_devices(self): + return list(self._parameter_devices) + + def non_slot_devices(self, var_list): + return min(var_list, key=lambda x: x.name) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cf29c0ed91a14843ce15bf671dd363ca0f7073c0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -0,0 +1,430 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ParameterServerStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import threading +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib + + +class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ], + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + + def setUp(self): + self._result = 0 + self._lock = threading.Lock() + self._init_condition = threading.Condition() + self._init_reached = 0 + self._finish_condition = threading.Condition() + self._finish_reached = 0 + super(ParameterServerStrategyTest, self).setUp() + + def _get_test_objects(self, task_type, task_id, num_gpus): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if not task_type: + return distribution, '' + + tf_config = { + 'cluster': self._cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id + } + } + with self._lock: + # Accessing environment variables should be protected by locks because + # environment variables are shared by all threads. + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + distribution.configure() + return distribution, self._workers[task_id].target + + def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) + d, _ = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + # Define a variable outside the call_for_each_tower scope. This is not + # recommended. + n = variable_scope.get_variable('n', initializer=10.0) + self.assertEqual(n.device, '/job:ps/task:0') + + def model_fn(): + if num_gpus == 0: + last_part_device = 'device:CPU:0' + else: + last_part_device = ( + 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, worker_device + '/' + last_part_device) + self.assertEqual(b.device, worker_device + '/' + last_part_device) + self.assertEqual(c.device, worker_device + '/' + last_part_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable('x', initializer=10.0) + x_add = x.assign_add(c) + e = a + c + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + self.assertEqual(x.device, '/job:ps/task:1') + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, + '/job:worker/replica:0/task:0/%s' % last_part_device) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable('y', initializer=20.0) + y_add = y.assign_add(x_add) + self.assertEqual(y.device, '/job:ps/task:1') + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable('z', initializer=10.0) + self.assertEqual(z.device, '/job:ps/task:0') + self.assertNotEqual(z.device, x.device) + + with ops.control_dependencies([y_add]): + z_add = z.assign_add(y) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, worker_device + '/' + last_part_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, worker_device + '/device:CPU:1') + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + v = variable_scope.get_variable('v', initializer=30.0) + h = f + 1.0 + self.assertIn('/job:ps/', u.device) + self.assertIn('/job:ps/', v.device) + # u and v are on different parameter servers. + self.assertTrue(u.device != x.device or v.device != x.device) + self.assertTrue(u.device == x.device or v.device == x.device) + # Here h is not on one worker. Note h.device is canonical while x.device + # is not but. + self.assertIn('/job:ps/', h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributed(self, num_gpus): + self._test_device_assignment_distributed('worker', 1, num_gpus) + + def _test_device_assignment_local(self, + d, + compute_device='CPU', + variable_device='CPU', + num_gpus=0): + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + def model_fn(): + if 'CPU' in compute_device: + tower_compute_device = '/device:CPU:0' + else: + tower_compute_device = ( + '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + tower_compute_device = device_util.canonicalize(tower_compute_device) + + if 'CPU' in variable_device: + tower_variable_device = '/device:CPU:0' + else: + tower_variable_device = ( + '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + tower_variable_device = device_util.canonicalize(tower_variable_device) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, tower_compute_device) + self.assertEqual(b.device, tower_compute_device) + self.assertEqual(c.device, tower_compute_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/device:GPU:2'): + x = variable_scope.get_variable('x', initializer=10.0) + x_add = x.assign_add(c) + e = a + c + self.assertEqual( + device_util.canonicalize(x.device), tower_variable_device) + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable('y', initializer=20.0) + y_add = y.assign_add(x_add) + self.assertEqual( + device_util.canonicalize(y.device), tower_variable_device) + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable('z', initializer=10.0) + self.assertEqual( + device_util.canonicalize(z.device), tower_variable_device) + + with ops.control_dependencies([y_add]): + z_add = z.assign_add(y) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, tower_compute_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1')) + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + h = f + 1.0 + self.assertEqual( + device_util.canonicalize(u.device), tower_variable_device) + self.assertEqual(device_util.canonicalize(x.device), h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + def testDeviceAssignmentLocalCPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=0) + self._test_device_assignment_local( + distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + + def testDeviceAssignmentLocalOneGPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=1) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + + def testDeviceAssignmentLocalTwoGPUs(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + + def _test_simple_increment(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + if hasattr(d, '_cluster_spec') and d._cluster_spec: + num_workers = len(d._cluster_spec.as_dict().get('worker', + ['dummy_worker'])) + else: + num_workers = 1 + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + + def model_fn(): + x = variable_scope.get_variable('x', initializer=10.0) + y = variable_scope.get_variable('y', initializer=20.0) + + x_add = x.assign_add(1.0, use_locking=True) + y_add = y.assign_add(1.0, use_locking=True) + + train_op = control_flow_ops.group([x_add, y_add]) + return x, y, train_op + + x, y, train_op = d.call_for_each_tower(model_fn) + train_op = d.group(d.unwrap(train_op)) + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != num_workers: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + sess.run(train_op) + + # Wait for other workers to finish training. + self._finish_condition.acquire() + self._finish_reached += 1 + while self._finish_reached != num_workers: + self._finish_condition.wait() + self._finish_condition.notify_all() + self._finish_condition.release() + + x_val, y_val = sess.run([x, y]) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers) + return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and + y_val == 20.0 + 1.0 * num_workers * d.num_towers) + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != 3: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, 0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testLocalSimpleIncrement(self, num_gpus): + self._test_simple_increment(None, 0, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index bc53898539d76320e331784f9a717be9491365e1..83af37fc8175d56c8c4b3c75c63862fd07131184 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,40 +21,72 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import tpu +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib +from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.training import device_util +from tensorflow.python.training import server_lib from tensorflow.python.util import nest +def get_tpu_system_metadata(tpu_cluster_resolver): + """Retrieves TPU system metadata given a TPUClusterResolver.""" + master = tpu_cluster_resolver.master() + + # pylint: disable=protected-access + cluster_def = (tpu_cluster_resolver.cluster_spec() + or server_lib.ClusterSpec({})).as_cluster_def() + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( + master, + cluster_def=cluster_def, + query_topology=True)) + + return tpu_system_metadata + + class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, num_cores_per_host=2): + def __init__(self, tpu_cluster_resolver): + """Initializes the TPUStrategy object. + + Args: + tpu_cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + """ # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. - super(TPUStrategy, self).__init__('/cpu:0') - # TODO(isaprykin): Auto-detect number of cores and hosts. - self._num_cores_per_host = num_cores_per_host + super(TPUStrategy, self).__init__('/device:CPU:0') + + self._tpu_cluster_resolver = tpu_cluster_resolver + self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) + # TODO(priyag): This should not be hardcoded here. - self._host = '/task:0/device:CPU:0' + self._host = '/device:CPU:0' def distribute_dataset(self, dataset_fn): # TODO(priyag): Perhaps distribute across cores here. return self._call_dataset_fn(dataset_fn) - # TODO(priyag): Deal with OutOfRange errors. + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. def _run_steps_on_dataset(self, fn, iterator, iterations, initial_loop_values=None): - # Enqueue ops + shapes = nest.flatten(iterator.output_shapes) if any([not s.is_fully_defined() for s in shapes]): raise ValueError( @@ -68,7 +100,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): control_deps = [] sharded_inputs = [] with ops.device(self._host): - for _ in range(self._num_cores_per_host): + for _ in range(self.num_towers): # Use control dependencies to ensure a deterministic ordering. with ops.control_dependencies(control_deps): inputs = nest.flatten(iterator.get_next()) @@ -93,58 +125,117 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): [constant_op.constant(0)], parallel_iterations=1) - # Dequeue ops def dequeue_fn(): - dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) return nest.pack_sequence_as(iterator.output_shapes, dequeued) # Wrap `fn` for repeat. if initial_loop_values is None: - initial_loop_values = [] - ctx = values.MultiStepContext(initial_loop_values) + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + ctx = values.MultiStepContext() def run_fn(*args, **kwargs): del args, kwargs fn_result = fn(ctx, dequeue_fn()) - if ctx.last_step_outputs is None: - ctx.last_step_outputs = [] - with ops.control_dependencies([fn_result]): - return array_ops.identity(ctx.last_step_outputs) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + if flat_last_step_outputs: + with ops.control_dependencies([fn_result]): + return [array_ops.identity(f) for f in flat_last_step_outputs] + else: + return fn_result - # Repeat # TODO(sourabhbajaj): The input to while loop should be based on the output # type of the step_fn def iterate_on_tpu(): - return tpu.repeat(iterations, run_fn, [initial_loop_values]) - - # Re-write and distribute computation. - # TODO(sourabhbajaj): Convert the output to PerDevice variable and - # implement support for that in reduce. - last_step_tensor_outputs = tpu.batch_parallel( - iterate_on_tpu, [], num_shards=self._num_cores_per_host) - - # Take index [0] of last_step_tensor_outputs as we wrapped - # initial_loop_values in a list in the `repeat` call. - return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops), - last_step_tensor_outputs[0], ctx) + return training_loop.repeat(iterations, run_fn, initial_loop_values) + + replicate_inputs = [[]] * self.num_towers + replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [x for x in replicate_outputs + if not isinstance(x, ops.Operation)] + + # Outputs are currently of the structure (grouped by device) + # [[output0_device0, output1_device0, output2_device0], + # [output0_device1, output1_device1, output2_device1]] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)] + + # Convert replicate_outputs to the original dict structure of + # last_step_outputs. + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, take the first value + # from the list as each value should be the same. Else return the full + # list of values. + if aggregation is not variables_lib.VariableAggregation.NONE: + # TODO(priyag): Should this return the element or a list with 1 element + last_step_tensor_outputs_dict[name] = output[0] + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + + return ctx def _call_for_each_tower(self, fn, *args, **kwargs): kwargs.pop('run_concurrently', None) with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access return fn(*args, **kwargs) - def get_initialization_ops(self): - return [tpu.initialize_system()] + def initialize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.initialize_system()] - def get_finalize_ops(self): - return [tpu.shutdown_system()] + def finalize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.shutdown_system()] def _reduce(self, aggregation, value, destinations): - del destinations # TPU is graph mode only. Rely on implicit Send/Recv. + graph = ops.get_default_graph() + cf_context = graph._get_control_flow_context() # pylint: disable=protected-access + # If we're inside the ReplicateContext, reduction should be done using + # CrossReplicaSum while outside we can directly use an add_n op. + while cf_context: + if isinstance(cf_context, tpu.TPUReplicateContext): + if aggregation == vs.VariableAggregation.MEAN: + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self.num_towers) + return tpu_ops.cross_replica_sum(value) + cf_context = cf_context.outer_context + + # Validate that the destination is same as the host device + # Note we don't do this when in replicate context as the reduction is + # performed on the TPU device itself. + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + assert device_util.canonicalize(devices[0]) == device_util.canonicalize( + self._host) + else: + raise ValueError('Multiple devices are not supported for TPUStrategy') + + output = math_ops.add_n(value) if aggregation == vs.VariableAggregation.MEAN: - # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self._num_cores_per_host) - return tpu_ops.cross_replica_sum(value) + return output * (1. / len(value)) + return output + + def _unwrap(self, value): + if isinstance(value, list): + return value + return [value] @property def num_towers(self): - return self._num_cores_per_host + return self._tpu_metadata.num_of_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 47dcf679c2a6280d4be523b7fb04f0d2ba5855e8..5fd4c9de696b715c3fb9b8a6ca64923b413a32e9 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver @@ -210,6 +211,11 @@ class DistributedVariable(DistributedDelegate): # without this it will use `__getattr__` which will delegate to a component # variable. self._keras_initialized = False + # Typically, a `DistributedVariable`'s initializer is composed of the + # initializers of the components variables. However, in some cases, such as + # when restoring from a checkpoint, we may set the _initializer_op + # property on the entire `DistributedVariable`. + self._initializer_op = None super(DistributedVariable, self).__init__(index) def is_initialized(self, name=None): @@ -239,9 +245,14 @@ class DistributedVariable(DistributedDelegate): @property def initializer(self): - # return grouped ops of all the var initializations of component values of - # the mirrored variable - return control_flow_ops.group([v.initializer for v in self._index.values()]) + if self._initializer_op: + init_op = self._initializer_op + else: + # return grouped ops of all the var initializations of component values of + # the mirrored variable + init_op = control_flow_ops.group( + [v.initializer for v in self._index.values()]) + return init_op @property def graph(self): @@ -284,6 +295,9 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op + def read_value(self): + return distribute_lib.get_distribution_strategy().read_var(self) + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass @@ -921,64 +935,120 @@ class MultiStepContext(object): This context object is useful when running multiple steps at a time using the `run_steps_on_dataset` API. For e.g. it allows the user's step function to - specify which outputs to emit at what frequency. Currently it only supports - capturing output from the last step, but will soon be augmented to support - other use cases such as output each N steps. + specify which outputs to emit at what frequency. Currently it supports + capturing output from the last step, as well as capturing non tensor outputs. + In the future it will be augmented to support other use cases such as output + each N steps. """ - def __init__(self, initial_loop_values=None): + def __init__(self): """Initializes an output context. - Args: - initial_loop_values: Initial values passed to the run steps - while loop. The only purpose is to verify the shapes and types - when the actual output is set. This will be removed once we - automatically infer the output shapes and types (and do not need to - check for user error in specifying them manually). Returns: A context object. """ - self._last_step_outputs = None - self._non_tensor_outputs = None - self._initial_loop_values = initial_loop_values + self._last_step_outputs = {} + self._last_step_outputs_aggregations = {} + self._non_tensor_outputs = {} @property def last_step_outputs(self): - """Return the last step's outputs.""" + """A dictionary consisting of outputs to be captured on last step. + + Keys in the dictionary are names of tensors to be captured, as specified + when `set_last_step_output` is called. + Values in the dictionary are the tensors themselves. If + `set_last_step_output` was called with an `aggregation` for this output, + then the value is the aggregated value. + + Returns: + A dictionary with last step outputs. + """ return self._last_step_outputs - @last_step_outputs.setter - def last_step_outputs(self, outputs): - """Set the last step's outputs.""" - self._verify_structure_shapes_types(outputs, self._initial_loop_values) + def _set_last_step_outputs(self, outputs): + """Replace the entire dictionary of last step outputs.""" + if not isinstance(outputs, dict): + raise ValueError("Need a dictionary to set last_step_outputs.") self._last_step_outputs = outputs + def set_last_step_output(self, name, output, + aggregation=variables_lib.VariableAggregation.NONE): + """Set `output` with `name` to be outputted from the last step. + + Args: + name: String, name to identify the output. Doesn't need to match tensor + name. + output: The tensors that should be outputted with `name`. See below for + actual types supported. + aggregation: Aggregation method to use to aggregate outputs from multiple + towers. Required if `set_last_step_output` is called in a tower context. + Optional in cross_tower_context. + When present, the outputs from all the towers are aggregated using the + current distribution strategy's `reduce` method. Hence, the type of + `output` must be what's supported by the corresponding `reduce` method. + For e.g. if using MirroredStrategy and aggregation is set, output + must be a `PerDevice` value. + The aggregation method is also recorded in a dictionary + `_last_step_outputs_aggregations` for later interpreting of the + outputs as already reduced or not. + + """ + if distribute_lib.get_cross_tower_context(): + self._last_step_outputs_aggregations[name] = aggregation + if aggregation is variables_lib.VariableAggregation.NONE: + self._last_step_outputs[name] = output + else: + distribution = distribute_lib.get_distribution_strategy() + self._last_step_outputs[name] = distribution.reduce( + aggregation, output, destinations="/device:CPU:0") + else: + assert aggregation is not variables_lib.VariableAggregation.NONE + def merge_fn(distribution, value): + self._last_step_outputs[name] = distribution.reduce( + aggregation, value, destinations="/device:CPU:0") + # Setting this inside the `merge_fn` because all towers share the same + # context object, so it's more robust to set it only once (even if all + # the towers are trying to set the same value). + self._last_step_outputs_aggregations[name] = aggregation + distribute_lib.get_tower_context().merge_call(merge_fn, output) + @property def non_tensor_outputs(self): - """Return the non tensor outputs.""" + """A dictionary consisting of any non tensor outputs to be captured.""" return self._non_tensor_outputs - @non_tensor_outputs.setter - def non_tensor_outputs(self, outputs): - """Set any non tensor outputs.""" - self._non_tensor_outputs = outputs - - def _verify_structure_shapes_types(self, left, right): - """Verify that the structure, shapes and types of left are same as right.""" - nest.assert_same_structure(left, right) - flat_left = nest.flatten(left) - flat_right = nest.flatten(right) - assert len(flat_left) == len(flat_right), ( - "Length of left {} and right {} should be same.". - format(len(flat_left), len(flat_right))) - - for o, i in zip(flat_left, flat_right): - # TODO(priyag): Add checks for other types like IndexedSlices. - if isinstance(o, ops.Tensor): - assert isinstance(i, ops.Tensor) - assert o.shape == i.shape, ( - "Shape {} of left {} doesn't match shape {} of right {}.". - format(o.shape, o, i.shape, i)) - assert o.dtype == i.dtype, ( - "Dtype {} of left {} doesn't match dtype {} of right {}.". - format(o.dtype, o, i.dtype, i)) + def set_non_tensor_output(self, name, output): + """Set `output` with `name` to be captured as a non tensor output.""" + if distribute_lib.get_cross_tower_context(): + self._non_tensor_outputs[name] = output + else: + def merge_fn(distribution, value): + # NOTE(priyag): For non tensor outputs, we simply return all the values + # in a list as aggregation doesn't make sense on non tensors. + self._non_tensor_outputs[name] = distribution.unwrap(value) + distribute_lib.get_tower_context().merge_call(merge_fn, output) + + +def value_container(val): + """Returns the container that this per-device `value` belongs to. + + Args: + val: A value returned by `call_for_each_tower()` or a variable + created in `scope()`. + + Returns: + A container that `value` belongs to. + If value does not belong to any container (including the case of + container having been destroyed), returns the value itself. + """ + # pylint: disable=protected-access + if (hasattr(val, "_distributed_container") and + # DistributedVariable has _distributed_container defined + # but we don't want to return it. + not isinstance(val, DistributedVariable)): + container = val._distributed_container() + # pylint: disable=protected-access + if container is not None: + return container + return val diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index ad00d1734dd14ed846522a33d888a5387cb25cc6..a8d0d493abcd7de540799f6b94c3cdb9ce9dafae 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -124,7 +124,7 @@ cuda_py_test( cuda_py_test( name = "conditional_distribution_test", - size = "small", + size = "medium", srcs = [ "python/kernel_tests/conditional_distribution_test.py", "python/kernel_tests/distribution_test.py", diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index ef3bdfa75fcaa8df17db1238ceadadf788601356..18a0f754e6e618f240db109f593a80dec57e200b 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -326,6 +326,21 @@ class QuantizedDistribution(distributions.Distribution): graph_parents=graph_parents, name=name) + @property + def distribution(self): + """Base distribution, p(x).""" + return self._dist + + @property + def low(self): + """Lowest value that quantization returns.""" + return self._low + + @property + def high(self): + """Highest value that quantization returns.""" + return self._high + def _batch_shape_tensor(self): return self.distribution.batch_shape_tensor() @@ -569,8 +584,3 @@ class QuantizedDistribution(distributions.Distribution): dependencies = [distribution_util.assert_integer_form( value, message="value has non-integer components.")] return control_flow_ops.with_dependencies(dependencies, value) - - @property - def distribution(self): - """Base distribution, p(x).""" - return self._dist diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index f5aaa5cf34abde3ea4d25de1ecf3adaef3f2a770..aa680a92be64cf0f099acd335369f2a1610c5953 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -134,7 +134,7 @@ def auto_correlation( x_len = util.prefer_static_shape(x_rotated)[-1] # TODO(langmore) Investigate whether this zero padding helps or hurts. At - # the moment is is necessary so that all FFT implementations work. + # the moment is necessary so that all FFT implementations work. # Zero pad to the next power of 2 greater than 2 * x_len, which equals # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). x_len_float64 = math_ops.cast(x_len, np.float64) @@ -198,7 +198,7 @@ def auto_correlation( # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The # other terms were zeros arising only due to zero padding. # `denominator = (N / 2 - m)` (defined below) is the proper term to - # divide by by to make this an unbiased estimate of the expectation + # divide by to make this an unbiased estimate of the expectation # E[X[n] Conj(X[n - m])]. x_len = math_ops.cast(x_len, dtype.real_dtype) max_lags = math_ops.cast(max_lags, dtype.real_dtype) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index e31dbbe80f9634e8e45ec91bf395eab82942c8ce..16844e0d6885919118adcd3f5a7777eec57b1e9c 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -22,12 +22,9 @@ from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.saver import BaseSaverBuilder -class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): +class Iterator(iterator_ops.EagerIterator): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the @@ -82,30 +79,3 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): return super(Iterator, self)._next_internal() - - # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset - # attributes(potential). - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" - - def __init__(self, iterator_resource, name): - serialized_iterator = gen_dataset_ops.serialize_iterator( - iterator_resource) - specs = [ - BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") - ] - # pylint: disable=protected-access - super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - - def restore(self, restored_tensors, restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, - restored_tensors[0]) - - def _gather_saveables_for_checkpoint(self): - - def _saveable_factory(name): - return self._Saveable(self._resource, name) - - return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index acc605247faffcf7ba83891dacdab13fc8c8574a..a753d77580758af9de8410de4a08f7ea278c4c79 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -306,6 +307,19 @@ class IteratorTest(test.TestCase): checkpoint.restore(save_path) self.assertEqual(2, iterator.get_next().numpy()) + def testRestoreInReconstructedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(10) + for i in range(5): + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) + for j in range(2): + self.assertEqual(i * 2 + j, iterator.get_next().numpy()) + checkpoint.save(file_prefix=checkpoint_prefix) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 12155a459c29c353c57679c407e7dda25047a35c..6f02c90368d966b8cf8d0dee09f9d2a5013c90c1 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -15,8 +15,6 @@ py_library( "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", - "//tensorflow/contrib/eager/python/examples/sagan", - "//tensorflow/contrib/eager/python/examples/sagan:config", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py index bd0057fb1a0175a805a0f7a1e4dcaa2bdc3c435a..4b3cb624bc947a1d1956eff6accb6d4da3bf3b87 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -128,8 +128,10 @@ class DensenetBenchmark(tf.test.Benchmark): weight_decay=1e-4, dropout_rate=0, pool_initial=True, include_top=True) logits = model(images, training=True) - loss = tf.losses.softmax_cross_entropy( + cross_ent = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) train_op = optimizer.minimize(loss) diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py index 4f19711fb87d6b5558302fd69104aca7e2cf403e..0736ed02b7437240e5da4dd529ad9ba9a5a15042 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py @@ -98,12 +98,52 @@ class DensenetTest(tf.test.TestCase): output_shape = model(rand_input).shape self.assertEqual(output_shape, (batch_size, output_classes)) + def test_regularization(self): + if tf.test.is_gpu_available(): + rand_input = tf.random_uniform((10, 3, 32, 32)) + data_format = 'channels_first' + else: + rand_input = tf.random_uniform((10, 32, 32, 3)) + data_format = 'channels_last' + weight_decay = 1e-4 + + conv = tf.keras.layers.Conv2D( + 3, (3, 3), + padding='same', + use_bias=False, + data_format=data_format, + kernel_regularizer=tf.keras.regularizers.l2(weight_decay)) + optimizer = tf.train.GradientDescentOptimizer(0.1) + conv(rand_input) # Initialize the variables in the layer + + def compute_true_l2(vs, wd): + return tf.reduce_sum(tf.square(vs)) * wd + + true_l2 = compute_true_l2(conv.variables, weight_decay) + keras_l2 = tf.add_n(conv.losses) + self.assertAllClose(true_l2, keras_l2) + + with tf.GradientTape() as tape_true, tf.GradientTape() as tape_keras: + loss = tf.reduce_sum(conv(rand_input)) + loss_with_true_l2 = loss + compute_true_l2(conv.variables, weight_decay) + loss_with_keras_l2 = loss + tf.add_n(conv.losses) + + true_grads = tape_true.gradient(loss_with_true_l2, conv.variables) + keras_grads = tape_keras.gradient(loss_with_keras_l2, conv.variables) + self.assertAllClose(true_grads, keras_grads) + + optimizer.apply_gradients(zip(keras_grads, conv.variables)) + keras_l2_after_update = tf.add_n(conv.losses) + self.assertNotAllClose(keras_l2, keras_l2_after_update) + def compute_gradients(model, images, labels): with tf.GradientTape() as tape: logits = model(images, training=True) - loss = tf.losses.softmax_cross_entropy( + cross_ent = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization tf.contrib.summary.scalar(name='loss', tensor=loss) return tape.gradient(loss, model.variables) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f91ae374488b735549e5c8fc03a309b14f8ba4be --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -0,0 +1,634 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Convolutional VAE: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "P-JuIu2N_SQf" + }, + "outputs": [], + "source": [ + "# to generate gifs\n", + "!pip install imageio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable Eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "tfe = tf.contrib.eager\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import glob\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import imageio\n", + "from IPython import display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the MNIST dataset\n", + "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "a4fYMGxGhrna" + }, + "outputs": [], + "source": [ + "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NFC2ghIdiZYE" + }, + "outputs": [], + "source": [ + "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", + "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", + "\n", + "# Normalizing the images to the range of [0., 1.]\n", + "train_images /= 255.\n", + "test_images /= 255.\n", + "\n", + "# Binarization\n", + "train_images[train_images \u003e= .5] = 1.\n", + "train_images[train_images \u003c .5] = 0.\n", + "test_images[test_images \u003e= .5] = 1.\n", + "test_images[test_images \u003c .5] = 0." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "S4PIDhoDLbsZ" + }, + "outputs": [], + "source": [ + "TRAIN_BUF = 60000\n", + "BATCH_SIZE = 100\n", + "\n", + "TEST_BUF = 10000" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use *tf.data* to create batches and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yKCCQOoJ7cn" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", + "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Wire up the generative and inference network with *tf.keras.Sequential*\n", + "\n", + "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", + "\n", + "### Generative Network\n", + "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", + "\n", + "### Inference Network\n", + "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", + "\n", + "### Reparameterization Trick\n", + "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", + "\n", + "### Network architecture\n", + "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VGLbvBEmjK0a" + }, + "outputs": [], + "source": [ + "class CVAE(tf.keras.Model):\n", + " def __init__(self, latent_dim):\n", + " super(CVAE, self).__init__()\n", + " self.latent_dim = latent_dim\n", + " self.inference_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", + " tf.keras.layers.Conv2D(\n", + " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2D(\n", + " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Flatten(),\n", + " # No activation\n", + " tf.keras.layers.Dense(latent_dim + latent_dim),\n", + " ]\n", + " )\n", + "\n", + " self.generative_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", + " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", + " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=64,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=32,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " # No activation\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", + " ]\n", + " )\n", + "\n", + " def sample(self, eps=None):\n", + " if eps is None:\n", + " eps = tf.random_normal(shape=(100, self.latent_dim))\n", + " return self.decode(eps, apply_sigmoid=True)\n", + "\n", + " def encode(self, x):\n", + " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", + " return mean, logvar\n", + "\n", + " def reparameterize(self, mean, logvar):\n", + " eps = tf.random_normal(shape=mean.shape)\n", + " return eps * tf.exp(logvar * .5) + mean\n", + "\n", + " def decode(self, z, apply_sigmoid=False):\n", + " logits = self.generative_net(z)\n", + " if apply_sigmoid:\n", + " probs = tf.sigmoid(logits)\n", + " return probs\n", + "\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss function and the optimizer\n", + "\n", + "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", + "\n", + "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", + "\n", + "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", + "\n", + "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", + "where $z$ is sampled from $q(z|x)$.\n", + "\n", + "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", + " log2pi = tf.log(2. * np.pi)\n", + " return tf.reduce_sum(\n", + " -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n", + " axis=raxis)\n", + "\n", + "def compute_loss(model, x):\n", + " mean, logvar = model.encode(x)\n", + " z = model.reparameterize(mean, logvar)\n", + " x_logit = model.decode(z)\n", + "\n", + " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", + " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", + " logpz = log_normal_pdf(z, 0., 0.)\n", + " logqz_x = log_normal_pdf(z, mean, logvar)\n", + " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", + "\n", + "def compute_gradients(model, x):\n", + " with tf.GradientTape() as tape:\n", + " loss = compute_loss(model, x)\n", + " return tape.gradient(loss, model.trainable_variables), loss\n", + "\n", + "optimizer = tf.train.AdamOptimizer(1e-4)\n", + "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", + " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", + "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", + "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", + "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, it is time to generate some images\n", + "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", + "* The generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", + "* Here we plot the probabilities of Bernoulli distributions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "epochs = 100\n", + "latent_dim = 50\n", + "num_examples_to_generate = 100\n", + "\n", + "# keeping the random vector constant for generation (prediction) so\n", + "# it will be easier to see the improvement.\n", + "random_vector_for_generation = tf.random_normal(\n", + " shape=[num_examples_to_generate, latent_dim])\n", + "model = CVAE(latent_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " predictions = model.sample(test_input)\n", + " fig = plt.figure(figsize=(10,10))\n", + "\n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(10, 10, i+1)\n", + " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", + " plt.axis('off')\n", + "\n", + " # tight_layout minimizes the overlap between 2 sub-plots\n", + " plt.tight_layout()\n", + " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "generate_and_save_images(model, 0, random_vector_for_generation)\n", + "\n", + "for epoch in range(1, epochs + 1):\n", + " start_time = time.time()\n", + " for train_x in train_dataset:\n", + " gradients, loss = compute_gradients(model, train_x)\n", + " apply_gradients(optimizer, gradients, model.trainable_variables)\n", + " end_time = time.time()\n", + "\n", + " if epoch % 5 == 0:\n", + " loss = tfe.metrics.Mean()\n", + " for test_x in test_dataset.make_one_shot_iterator():\n", + " loss(compute_loss(model, test_x))\n", + " elbo = -loss.result()\n", + " display.clear_output(wait=False)\n", + " print('Epoch: {}, Test set ELBO: {}, '\n", + " 'time elapse for current epoch {}'.format(epoch,\n", + " elbo,\n", + " end_time - start_time))\n", + " generate_and_save_images(\n", + " model, epoch, random_vector_for_generation)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4M_vIbUi7c0" + }, + "source": [ + "### Display an image using the epoch number" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WfO5wCdclHGL" + }, + "outputs": [], + "source": [ + "def display_image(epoch_no):\n", + " plt.figure(figsize=(15,15))\n", + " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n", + " plt.axis('off')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5x3q9_Oe5q0A" + }, + "outputs": [], + "source": [ + "display_image(epochs) # Display images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NywiH3nL8guF" + }, + "source": [ + "### Generate a GIF of all the saved images." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IGKQgENQ8lEI" + }, + "outputs": [], + "source": [ + "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", + " filenames = glob.glob('image*.png')\n", + " filenames = sorted(filenames)\n", + " for filename in filenames:\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " # this is a hack to display the gif inside the notebook\n", + " os.system('mv cvae.gif cvae.gif.png')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "uV0yiKpzNP1b" + }, + "outputs": [], + "source": [ + "display.Image(filename=\"cvae.gif.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "JGZBy7glUU2O" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "last_runtime": { + "build_target": "//learning/brain/python/client:colab_notebook", + "kind": "private" + }, + "name": "cvae.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 54cc4dc5daa6b9da8d636cb5ca5d2e2d1cc1ff37..44ff43a1112e771eb6c91c398286a003e17632e0 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -29,7 +29,7 @@ "source": [ "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). To do so, we use Deep Convolutional Generative Adverserial Networks ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)).\n", "\n", - "This model takes about 40 seconds per epoch to train on a single Tesla K80 on Colab, as of July 2018.\n", + "This model takes about ~30 seconds per epoch (using tf.contrib.eager.defun to create graph functions) to train on a single Tesla K80 on Colab, as of July 2018.\n", "\n", "Below is the output generated after training the generator and discriminator models for 150 epochs.\n", "\n", @@ -203,7 +203,7 @@ "## Write the generator and discriminator models\n", "\n", "* **Generator** \n", - " * It is responsible for **creating the convincing images good enough to fool the discriminator**.\n", + " * It is responsible for **creating convincing images that are good enough to fool the discriminator**.\n", " * It consists of Conv2DTranspose (Upsampling) layers. We start with a fully connected layer and upsample the image 2 times so as to reach the desired image size (mnist image size) which is (28, 28, 1). \n", " * We use **leaky relu** activation except for the **last layer** which uses **tanh** activation.\n", " \n", @@ -314,6 +314,26 @@ "discriminator = Discriminator()" ] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "k1HpMSLImuRi" + }, + "outputs": [], + "source": [ + "# Defun gives 10 secs/epoch performance boost\n", + "generator.call = tf.contrib.eager.defun(generator.call)\n", + "discriminator.call = tf.contrib.eager.defun(discriminator.call)" + ] + }, { "cell_type": "markdown", "metadata": { diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index 6be09f98dff6627d9bdcc8056eed14e2a621be4b..b173f856c641b4d7dca96adda113f904c97a25a7 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -1,31 +1,11 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "text_generation.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "accelerator": "GPU" - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "hcD2nPQvPOFM", - "colab_type": "text" + "colab_type": "text", + "id": "hcD2nPQvPOFM" }, - "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors.\n", "\n", @@ -33,19 +13,19 @@ "\n", "# Text Generation using a RNN\n", "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on Github
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on Github\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { - "id": "BwpJ5IffzRG6", - "colab_type": "text" + "colab_type": "text", + "id": "BwpJ5IffzRG6" }, - "cell_type": "markdown", "source": [ "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", "\n", @@ -102,58 +82,60 @@ ] }, { + "cell_type": "markdown", "metadata": { - "id": "R3p22DBDsaCA", - "colab_type": "text" + "colab_type": "text", + "id": "R3p22DBDsaCA" }, - "cell_type": "markdown", "source": [ "## Install unidecode library\n", "A helpful library to convert unicode to ASCII." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "wZ6LOM12wKGH", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "wZ6LOM12wKGH" }, - "cell_type": "code", + "outputs": [], "source": [ "!pip install unidecode" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "WGyKZj3bzf9p", - "colab_type": "text" + "colab_type": "text", + "id": "WGyKZj3bzf9p" }, - "cell_type": "markdown", "source": [ "## Import tensorflow and enable eager execution." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "yG_n40gFzf9s", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "yG_n40gFzf9s" }, - "cell_type": "code", + "outputs": [], "source": [ - "# Import TensorFlow >= 1.9 and enable eager execution\n", + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", "import tensorflow as tf\n", "\n", "# Note: Once you enable eager execution, it cannot be disabled. \n", @@ -164,16 +146,14 @@ "import random\n", "import unidecode\n", "import time" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "EHDoRoc5PKWz", - "colab_type": "text" + "colab_type": "text", + "id": "EHDoRoc5PKWz" }, - "cell_type": "markdown", "source": [ "## Download the dataset\n", "\n", @@ -182,76 +162,78 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "pD_55cOxLkAb", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "pD_55cOxLkAb" }, - "cell_type": "code", + "outputs": [], "source": [ - "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/yashkatariya/shakespeare.txt')" - ], - "execution_count": 0, - "outputs": [] + "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "UHjdCjDuSvX_", - "colab_type": "text" + "colab_type": "text", + "id": "UHjdCjDuSvX_" }, - "cell_type": "markdown", "source": [ "## Read the dataset\n", "\n" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "-E5JvY3wzf94", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "-E5JvY3wzf94" }, - "cell_type": "code", + "outputs": [], "source": [ "text = unidecode.unidecode(open(path_to_file).read())\n", "# length of text is the number of characters in it\n", "print (len(text))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "Il9ww98izf-D", - "colab_type": "text" + "colab_type": "text", + "id": "Il9ww98izf-D" }, - "cell_type": "markdown", "source": [ "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "IalZLbvOzf-F", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "IalZLbvOzf-F" }, - "cell_type": "code", + "outputs": [], "source": [ "# unique contains all the unique characters in the file\n", "unique = sorted(set(text))\n", @@ -259,22 +241,22 @@ "# creating a mapping from unique characters to indices\n", "char2idx = {u:i for i, u in enumerate(unique)}\n", "idx2char = {i:u for i, u in enumerate(unique)}" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "1v_qUYfAzf-I", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "1v_qUYfAzf-I" }, - "cell_type": "code", + "outputs": [], "source": [ "# setting the maximum length sentence we want for a single input in characters\n", "max_length = 100\n", @@ -293,16 +275,14 @@ "\n", "# buffer size to shuffle our dataset\n", "BUFFER_SIZE = 10000" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "LFjSVAlWzf-N", - "colab_type": "text" + "colab_type": "text", + "id": "LFjSVAlWzf-N" }, - "cell_type": "markdown", "source": [ "## Creating the input and output tensors\n", "\n", @@ -319,17 +299,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "0UHJDA39zf-O", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "0UHJDA39zf-O" }, - "cell_type": "code", + "outputs": [], "source": [ "input_text = []\n", "target_text = []\n", @@ -343,45 +325,43 @@ " \n", "print (np.array(input_text).shape)\n", "print (np.array(target_text).shape)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "MJdfPmdqzf-R", - "colab_type": "text" + "colab_type": "text", + "id": "MJdfPmdqzf-R" }, - "cell_type": "markdown", "source": [ "## Creating batches and shuffling them using tf.data" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "p2pGotuNzf-S", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "p2pGotuNzf-S" }, - "cell_type": "code", + "outputs": [], "source": [ "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "m8gPwEjRzf-Z", - "colab_type": "text" + "colab_type": "text", + "id": "m8gPwEjRzf-Z" }, - "cell_type": "markdown", "source": [ "## Creating the model\n", "\n", @@ -393,17 +373,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "P3KTiiInzf-a", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "P3KTiiInzf-a" }, - "cell_type": "code", + "outputs": [], "source": [ "class Model(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n", @@ -447,66 +429,64 @@ " x = self.fc(output)\n", "\n", " return x, states" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "trpqTWyvk0nr", - "colab_type": "text" + "colab_type": "text", + "id": "trpqTWyvk0nr" }, - "cell_type": "markdown", "source": [ "## Call the model and set the optimizer and the loss function" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "7t2XrzEOzf-e", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "7t2XrzEOzf-e" }, - "cell_type": "code", + "outputs": [], "source": [ "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "dkjWIATszf-h", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "dkjWIATszf-h" }, - "cell_type": "code", + "outputs": [], "source": [ "optimizer = tf.train.AdamOptimizer()\n", "\n", "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n", "def loss_function(real, preds):\n", " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "lPrP0XMUzf-p", - "colab_type": "text" + "colab_type": "text", + "id": "lPrP0XMUzf-p" }, - "cell_type": "markdown", "source": [ "## Train the model\n", "\n", @@ -531,17 +511,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "d4tSNwymzf-q", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "d4tSNwymzf-q" }, - "cell_type": "code", + "outputs": [], "source": [ "# Training step\n", "\n", @@ -574,16 +556,14 @@ " \n", " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "DjGz1tDkzf-u", - "colab_type": "text" + "colab_type": "text", + "id": "DjGz1tDkzf-u" }, - "cell_type": "markdown", "source": [ "## Predicting using our trained model\n", "\n", @@ -601,17 +581,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "WvuwZBX5Ogfd", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "WvuwZBX5Ogfd" }, - "cell_type": "code", + "outputs": [], "source": [ "# Evaluation step(generating text using the model learned)\n", "\n", @@ -648,16 +630,14 @@ " text_generated += idx2char[predicted_id]\n", "\n", "print (start_string + text_generated)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "AM2Uma_-yVIq", - "colab_type": "text" + "colab_type": "text", + "id": "AM2Uma_-yVIq" }, - "cell_type": "markdown", "source": [ "## Next steps\n", "\n", @@ -668,22 +648,42 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "gtEd86sX5cB2", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "gtEd86sX5cB2" }, - "cell_type": "code", + "outputs": [], "source": [ "" - ], - "execution_count": 0, - "outputs": [] + ] } - ] + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "text_generation.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/README.md b/tensorflow/contrib/eager/python/examples/l2hmc/README.md index d6a2ff7558c76c714df1674c4c8c627fa433f197..f171806e379da7213b6ee33e0d454056068fe7a5 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/README.md +++ b/tensorflow/contrib/eager/python/examples/l2hmc/README.md @@ -4,16 +4,15 @@ This folder contains an implementation of [L2HMC](https://arxiv.org/pdf/1711.092 With eager execution enabled, longer sample chains can be handled compared to graph mode, since no graph is explicitly stored. Moreover, with eager execution enabled, there is no need to use a `tf.while_loop`. ## What is L2HMC? -L2HMC is an algorithm that learns a non-volume preserving transformation -for an HMC-like sampling algorithm. More specifically, the non-volume preserving +L2HMC is an adaptive Markov Chain Monte Carlo (MCMC) algorithm that learns a non-volume preserving transformation +for a Hamiltonian Monte Carlo (HMC) sampling algorithm. More specifically, the non-volume preserving transformation is learned with neural nets instantiated within Normalizing Flows -(more precisely, real-NVPs). +(real-NVPs). ## Content - `l2hmc.py`: Dynamics definitions and example energy functions, -including the 2D strongly correlated Gaussian, the rough well energy function, -and a Gaussian mixture model. +including the 2D strongly correlated Gaussian and the rough well energy function, - `l2hmc_test.py`: Unit tests and benchmarks for training a sampler on the energy functions in both eager and graph mode. - `neural_nets.py`: The neural net for learning the kernel on the 2D strongly correlated example. - `main.py`: Run to train a samplers on 2D energy landscapes. @@ -32,7 +31,7 @@ tensorboard and a plot of sampled chain from the trained sampler. Specifying the optional argument `use_defun` will let the program use compiled graphs when running specific sections and improve the overall speed. -## Boosting Performance with `defun` +## Boosting Performance with `tfe.defun` Currently, some models may experience increased overhead with eager execution enabled. To improve performance, we could wrap certain functions with the decorator `@tfe.defun`. For example, we could wrap the function that does the sampling step: diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 1f66d7e75299df0c7db9bc8ec67cb6c0b5d4de40..1ab1b71bd0549e06a1d86611c21faef1f182d740 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -383,6 +383,7 @@ "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", + "N_BATCH = BUFFER_SIZE//BATCH_SIZE\n", "embedding_dim = 256\n", "units = 1024\n", "vocab_inp_size = len(inp_lang.word2idx)\n", @@ -677,21 +678,23 @@ " # using teacher forcing\n", " dec_input = tf.expand_dims(targ[:, t], 1)\n", " \n", - " total_loss += (loss / int(targ.shape[1]))\n", + " batch_loss = (loss / int(targ.shape[1]))\n", + " \n", + " total_loss += batch_loss\n", " \n", " variables = encoder.variables + decoder.variables\n", " \n", " gradients = tape.gradient(loss, variables)\n", - " \n", + " \n", " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - "\n", + " \n", " if batch % 100 == 0:\n", " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", " batch,\n", - " loss.numpy() / int(targ.shape[1])))\n", + " batch_loss.numpy()))\n", " \n", " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", - " total_loss/len(input_tensor)))\n", + " total_loss / N_BATCH))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" ], "execution_count": 0, @@ -906,4 +909,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb index 7c0f9b5b8161a763c4153ebdeece7e0d1b90b384..51b7ffc4de0cee31f7a907ae7bf90f17056f9bcf 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -1,46 +1,30 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "automatic_differentiation.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "t09eeeR5prIJ", - "colab_type": "text" + "colab_type": "text", + "id": "t09eeeR5prIJ" }, - "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "GCCk8_dHpuNf", - "colab_type": "code", + "cellView": "form", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, - "cellView": "form" + "colab_type": "code", + "id": "GCCk8_dHpuNf" }, - "cell_type": "code", + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -53,81 +37,79 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "xh8WkEwWpnm7", - "colab_type": "text" + "colab_type": "text", + "id": "xh8WkEwWpnm7" }, - "cell_type": "markdown", "source": [ "# Automatic differentiation and gradient tape" ] }, { + "cell_type": "markdown", "metadata": { - "id": "idv0bPeCp325", - "colab_type": "text" + "colab_type": "text", + "id": "idv0bPeCp325" }, - "cell_type": "markdown", "source": [ - "
\n", - "\n", - " Run in Google Colab\n", - "\n", - "View source on GitHub
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { - "id": "vDJ4XzMqodTy", - "colab_type": "text" + "colab_type": "text", + "id": "vDJ4XzMqodTy" }, - "cell_type": "markdown", "source": [ "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." ] }, { + "cell_type": "markdown", "metadata": { - "id": "GQJysDM__Qb0", - "colab_type": "text" + "colab_type": "text", + "id": "GQJysDM__Qb0" }, - "cell_type": "markdown", "source": [ "## Setup\n" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "OiMPZStlibBv", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "OiMPZStlibBv" }, - "cell_type": "code", + "outputs": [], "source": [ "import tensorflow as tf\n", "tf.enable_eager_execution()\n", "\n", "tfe = tf.contrib.eager # Shorthand for some symbols" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "1CLWJl0QliB0", - "colab_type": "text" + "colab_type": "text", + "id": "1CLWJl0QliB0" }, - "cell_type": "markdown", "source": [ "## Derivatives of a function\n", "\n", @@ -135,17 +117,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "9FViq92UX7P8", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "9FViq92UX7P8" }, - "cell_type": "code", + "outputs": [], "source": [ "from math import pi\n", "\n", @@ -159,17 +143,15 @@ "# with respect to its arguments. Since f() has a single argument,\n", "# grad_f will return a list with a single element.\n", "grad_f = tfe.gradients_function(f)\n", - "assert tf.abs(grad_f(pi/2)[0]).numpy() < 1e-7" - ], - "execution_count": 0, - "outputs": [] + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "v9fPs8RyopCf", - "colab_type": "text" + "colab_type": "text", + "id": "v9fPs8RyopCf" }, - "cell_type": "markdown", "source": [ "### Higher-order gradients\n", "\n", @@ -177,17 +159,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "3D0ZvnGYo0rW", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "3D0ZvnGYo0rW" }, - "cell_type": "code", + "outputs": [], "source": [ "def f(x):\n", " return tf.square(tf.sin(x))\n", @@ -205,16 +189,14 @@ "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", "plt.legend()\n", "plt.show()" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "-39gouo7mtgu", - "colab_type": "text" + "colab_type": "text", + "id": "-39gouo7mtgu" }, - "cell_type": "markdown", "source": [ "## Gradient tapes\n", "\n", @@ -225,21 +207,25 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "MH0UfjympWf7", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "MH0UfjympWf7" }, - "cell_type": "code", + "outputs": [], "source": [ "def f(x, y):\n", " output = 1\n", - " for i in range(y):\n", + " # Must use range(int(y)) instead of range(y) in Python 3 when\n", + " # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+\n", + " for i in range(int(y)):\n", " output = tf.multiply(output, x)\n", " return output\n", "\n", @@ -251,16 +237,14 @@ "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "aNmR5-jhpX2t", - "colab_type": "text" + "colab_type": "text", + "id": "aNmR5-jhpX2t" }, - "cell_type": "markdown", "source": [ "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", "\n", @@ -268,17 +252,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "bAFeIE8EuVIq", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "bAFeIE8EuVIq" }, - "cell_type": "code", + "outputs": [], "source": [ "x = tf.ones((2, 2))\n", " \n", @@ -300,16 +286,14 @@ "for i in [0, 1]:\n", " for j in [0, 1]:\n", " assert dz_dx[i][j].numpy() == 8.0" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "DK05KXrAAld3", - "colab_type": "text" + "colab_type": "text", + "id": "DK05KXrAAld3" }, - "cell_type": "markdown", "source": [ "### Higher-order gradients\n", "\n", @@ -317,17 +301,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "cPQgthZ7ugRJ", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "cPQgthZ7ugRJ" }, - "cell_type": "code", + "outputs": [], "source": [ "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", "\n", @@ -344,21 +330,37 @@ "\n", "assert dy_dx.numpy() == 3.0\n", "assert d2y_dx2.numpy() == 6.0" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "4U1KKzUpNl58", - "colab_type": "text" + "colab_type": "text", + "id": "4U1KKzUpNl58" }, - "cell_type": "markdown", "source": [ "## Next Steps\n", "\n", "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." ] } - ] -} \ No newline at end of file + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "automatic_differentiation.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index 3316dc111431811a17fc4435e619f07ff377b751..4f0d46b1bae3760a63b2abe871034bdedf258f07 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -43,6 +43,27 @@ py_library( ], ) +py_library( + name = "resnet_preprocessing", + srcs = ["resnet_preprocessing.py"], + srcs_version = "PY2AND3", + tags = ["local"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "imagenet_input", + srcs = ["imagenet_input.py"], + srcs_version = "PY2AND3", + tags = ["local"], + deps = [ + ":resnet_preprocessing", + "//tensorflow:tensorflow_py", + ], +) + # Tests cuda_py_test( name = "ops_test", diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md index 21fc44febc8abdc30daad1b35d8434b083360bdf..822d86e9c7a7e620da3b84ded9af98b1c1d4b701 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/README.md +++ b/tensorflow/contrib/eager/python/examples/revnet/README.md @@ -1,19 +1,22 @@ # RevNet with TensorFlow eager execution -This folder contains an TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster. +This folder contains a TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran with both eager and graph execution. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the a redundant forward pass in the implementation by the authors. This saves us from using `tf.stop_gradient` and makes the model run faster. ## Content - `revnet.py`: The RevNet model. - `blocks.py`: The relevant reversible blocks. +- `ops.py`: Auxiliary downsampling operation. - `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100. - `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API. - `config.py`: Configuration file for network architectures and training hyperparameters. - `main.py`: Main training and evaluation script. -- `ops.py`: Auxiliary downsampling operation. +- `main_estimator.py`: Script to train RevNet models on CIFAR-10 and CIFAR-100 with the `tf.estimator` API. +- `main_estimator_tpu.py`: Script to train RevNet models on ImageNet with TPU estimators on Cloud TPUs. +- `resnet_preprocessing.py`, `imagenet_input.py`: Boilerplate to read ImageNet data from TFRecords. -## To run -- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly` +## Train on CIFAR-10/CIFAR-100 +- Make sure you have installed TensorFlow 1.10+ or the latest `tf-nightly` or `tf-nightly-gpu` pip package in order to access the eager execution feature. - First run @@ -24,7 +27,7 @@ python cifar_tfrecords.py --data_dir ${PWD}/cifar to download the cifar dataset and convert them to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100. -- To train a model run +- To train a model, run ```bash python main.py --data_dir ${PWD}/cifar @@ -34,11 +37,75 @@ python main.py --data_dir ${PWD}/cifar - `train_dir`: Directory to store eventfiles and checkpoints. - `restore`: Restore the latest checkpoint. - `validate`: Use validation set for training monitoring. - - `manual_grad`: Use the manually defined gradient map given by the authors. - - `dataset`: Use either `cifar-10` or `cifar-100` + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `use_defun`: Use `tfe.defun` to boost performance. + +- To train a model with estimators in graph execution, run + +```bash +python main_estimator.py --data_dir ${PWD}/cifar +``` +To ensure our code works properly when using the Keras model in an estimator, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +- Optional arguments for `main.py` include + - `model_dir`: Directory to store eventfiles and checkpoints. + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `export`: Export the model for serving if True. + +## Speed up with `tfe.defun` +To ensure that `tf.contrib.eager.defun` in our code works properly with all +part of the model during training, the latest `tf-nightly` or `tf-nightly-gpu` +is highly recommended as of August 2018. + +Even though the speed difference between pure eager execution and graph execution is noticeable, +the difference between fully "defunned" model training and graph +training is negligible. + +## Train on ImageNet with Cloud TPUs +The standard way to train models on Cloud TPUs is via TPU estimators and graph +execution. Models built with the `tf.keras` API are fully compatible with TPU estimators. +To ensure our code works properly in this setting, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +### Setup a Google Cloud project + +Follow the instructions at the [Quickstart Guide](https://cloud.google.com/tpu/docs/quickstart) +to get a GCE VM with access to Cloud TPU. + +To run this model, you will need: + +* A GCE VM instance with an associated Cloud TPU resource +* A GCS bucket to store your training checkpoints +* (Optional): The ImageNet training and validation data preprocessed into + TFRecord format, and stored in GCS. + +### Format the data + +The data is expected to be formatted in TFRecord format, as generated by [this +script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py). + +If you do not have ImageNet dataset prepared, you can use a randomly generated +fake dataset to test the model. It is located at +`gs://cloud-tpu-test-datasets/fake_imagenet`. + +### Start training + +Train the model by executing the following command (substituting the appropriate +values): + +```bash +python main_estimator_tpu.py \ + --tpu=$TPU_NAME \ + --data_dir=$DATA_DIR \ + --model_dir=$MODEL_DIR +``` ## Performance -- With the current implementation, RevNet-38 achieves >92% on CIFAR-10 and >71% on CIFAR-100. +- RevNet-38 achieves >92% and >71% accuracy on CIFAR-10 and CIFAR-100 respectively. +- RevNet-56 achieves <26% top-1 error rate on ImageNet. ## Reference The Reversible Residual Network: Backpropagation Without Storing Activations. diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py index 8a530b0d71afab6dfc57ed16120a621cafcc3181..f61354bc38a9fcb941f186cac4eac8097eea742d 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -91,32 +91,21 @@ class RevBlock(tf.keras.Model): h = block(h, training=training) return h - def backward_grads_and_vars(self, x, y, dy, training=True): + def backward_grads(self, x, y, dy, training=True): """Apply reversible block backward to outputs.""" grads_all = [] - vars_all = [] - for i in reversed(range(len(self.blocks))): block = self.blocks[i] if i == 0: # First block usually contains downsampling that can't be reversed - with tf.GradientTape() as tape: - tape.watch(x) - y = block(x, training=training) - - grads_combined = tape.gradient( - y, [x] + block.trainable_variables, output_gradients=dy) - dy = grads_combined[0] - grads_all += grads_combined[1:] - vars_all += block.trainable_variables + dy, grads = block.backward_grads_with_downsample( + x, y, dy, training=True) else: - y, dy, grads, vars_ = block.backward_grads_and_vars( - y, dy, training=training) - grads_all += grads - vars_all += vars_ + y, dy, grads = block.backward_grads(y, dy, training=training) + grads_all = grads + grads_all - return dy, grads_all, vars_all + return dy, grads_all class _Residual(tf.keras.Model): @@ -178,10 +167,9 @@ class _Residual(tf.keras.Model): fused=fused, dtype=dtype) - def call(self, x, training=True, concat=True): + def call(self, x, training=True): """Apply residual block to inputs.""" - - x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) + x1, x2 = x f_x2 = self.f(x2, training=training) x1_down = ops.downsample( x1, self.filters // 2, self.strides, axis=self.axis) @@ -190,42 +178,81 @@ class _Residual(tf.keras.Model): y1 = f_x2 + x1_down g_y1 = self.g(y1, training=training) y2 = g_y1 + x2_down - if not concat: # For correct backward grads - return y1, y2 - return tf.concat([y1, y2], axis=self.axis) + return y1, y2 - def backward_grads_and_vars(self, y, dy, training=True): + def backward_grads(self, y, dy, training=True): """Manually compute backward gradients given input and output grads.""" - dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis) + dy1, dy2 = dy + y1, y2 = y - with tf.GradientTape(persistent=True) as tape: - tape.watch(y) - y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) + with tf.GradientTape() as gtape: + gtape.watch(y1) gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( + gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) + dg = grads_combined[1:] + dx1 = dy1 + grads_combined[0] + # This doesn't affect eager execution, but improves memory efficiency with + # graphs + with tf.control_dependencies(dg + [dx1]): x2 = y2 - gy1 + + with tf.GradientTape() as ftape: + ftape.watch(x2) fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dx1) + df = grads_combined[1:] + dx2 = dy2 + grads_combined[0] + # Same behavior as above + with tf.control_dependencies(df + [dx2]): x1 = y1 - fx2 - grads_combined = tape.gradient( + x = x1, x2 + dx = dx1, dx2 + grads = df + dg + + return x, dx, grads + + def backward_grads_with_downsample(self, x, y, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + # Splitting this from `backward_grads` for better readability + x1, x2 = x + y1, _ = y + dy1, dy2 = dy + + with tf.GradientTape() as gtape: + gtape.watch(y1) + gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) dg = grads_combined[1:] - dx1 = dy1 + grads_combined[0] + dz1 = dy1 + grads_combined[0] - grads_combined = tape.gradient( - fx2, [x2] + self.f.trainable_variables, output_gradients=dx1) - dx2 = dy2 + grads_combined[0] - df = grads_combined[1:] + # dx1 need one more step to backprop through downsample + with tf.GradientTape() as x1tape: + x1tape.watch(x1) + z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis) + dx1 = x1tape.gradient(z1, x1, output_gradients=dz1) - del tape + with tf.GradientTape() as ftape: + ftape.watch(x2) + fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dz1) + dx2, df = grads_combined[0], grads_combined[1:] - grads = df + dg - vars_ = self.f.trainable_variables + self.g.trainable_variables + # dx2 need one more step to backprop through downsample + with tf.GradientTape() as x2tape: + x2tape.watch(x2) + z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis) + dx2 += x2tape.gradient(z2, x2, output_gradients=dy2) - x = tf.concat([x1, x2], axis=self.axis) - dx = tf.concat([dx1, dx2], axis=self.axis) + dx = dx1, dx2 + grads = df + dg - return x, dx, grads, vars_ + return dx, grads # Ideally, the following should be wrapped in `tf.keras.Sequential`, however @@ -422,7 +449,7 @@ class InitBlock(tf.keras.Model): if self.config.init_max_pool: net = self.max_pool(net) - return net + return tf.split(net, num_or_size_splits=2, axis=self.axis) class FinalBlock(tf.keras.Model): @@ -468,7 +495,7 @@ class FinalBlock(tf.keras.Model): self.config.n_classes, dtype=self.config.dtype) def call(self, x, training=True): - net = x + net = tf.concat(x, axis=self.axis) net = self.batch_norm(net, training=training) net = self.activation(net) net = self.global_avg_pool(net) diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py index d74785c8fe1c170ee95172974141c1cfe18b9502..9ff6b605b912772a92ab9e07a0ba5b9325030e43 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -116,70 +116,13 @@ def _validate_block_call_channels_first(block_factory, test): class RevBlockTest(tf.test.TestCase): - def test_call_channels_first(self): - """Test `call` function with `channels_first` data format.""" - if not tf.test.is_gpu_available(): - self.skipTest("GPU not available") - - with tf.device("/gpu:0"): # Default NCHW format - input_shape = (128, 8, 8) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - - # Stride of 1 - block = blocks.RevBlock( - n_res=3, filters=128, strides=(1, 1), input_shape=input_shape) - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 128, 8, 8)) - self.assertNotAllClose(y_tr, y_ev) - - # Stride of 2 - block = blocks.RevBlock( - n_res=3, filters=128, strides=(2, 2), input_shape=input_shape) - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, [16, 128, 4, 4]) - self.assertNotAllClose(y_tr, y_ev) - - def test_call_channels_last(self): - """Test `call` function with `channels_last` data format.""" - with tf.device("/cpu:0"): # NHWC format - input_shape = (8, 8, 128) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - - # Stride 1 - block = blocks.RevBlock( - n_res=3, - filters=128, - strides=(1, 1), - input_shape=input_shape, - data_format="channels_last") - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 8, 8, 128)) - self.assertNotAllClose(y_tr, y_ev) - - # Stride of 2 - block = blocks.RevBlock( - n_res=3, - filters=128, - strides=(2, 2), - input_shape=input_shape, - data_format="channels_last") - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 4, 4, 128)) - self.assertNotAllClose(y_tr, y_ev) - def _check_grad_angle(self, grads, grads_true, atol=1e0): """Check the angle between two list of vectors are all close.""" for g1, g2 in zip(grads, grads_true): degree = compute_degree(g1, g2) self.assertLessEqual(degree, atol) - def test_backward_grads_and_vars_channels_first(self): + def test_backward_grads_channels_first(self): """Test `backward` function with `channels_first` data format.""" if not tf.test.is_gpu_available(): self.skipTest("GPU not available") @@ -190,6 +133,7 @@ class RevBlockTest(tf.test.TestCase): data_shape = (16,) + input_shape x = tf.random_normal(shape=data_shape, dtype=tf.float64) dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) block = blocks.RevBlock( n_res=3, filters=128, @@ -199,9 +143,14 @@ class RevBlockTest(tf.test.TestCase): dtype=tf.float64) with tf.GradientTape() as tape: tape.watch(x) - y = block(x, training=True) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) # Compute grads from reconstruction - dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True) + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables # Compute true grads grads = tape.gradient(y, [x] + vars_, output_gradients=dy) dx_true, dw_true = grads[0], grads[1:] @@ -213,6 +162,7 @@ class RevBlockTest(tf.test.TestCase): # Stride 2 x = tf.random_normal(shape=data_shape, dtype=tf.float64) dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) block = blocks.RevBlock( n_res=3, filters=128, @@ -222,9 +172,14 @@ class RevBlockTest(tf.test.TestCase): dtype=tf.float64) with tf.GradientTape() as tape: tape.watch(x) - y = block(x, training=True) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) # Compute grads from reconstruction - dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True) + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables # Compute true grads grads = tape.gradient(y, [x] + vars_, output_gradients=dy) dx_true, dw_true = grads[0], grads[1:] @@ -233,19 +188,44 @@ class RevBlockTest(tf.test.TestCase): self._check_grad_angle(dx_true, dx) self._check_grad_angle(dw_true, dw) + def test_backward_grads_with_nativepy(self): + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") -class _ResidualTest(tf.test.TestCase): + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) - def test_call(self): - """Test `call` function. + # Compute true grads + dx_true = tape.gradient(y, x, output_gradients=dy) - Varying downsampling and data format options. - """ + # Compute grads from reconstruction + (dx1, dx2), _ = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) - _validate_block_call_channels_first(blocks._Residual, self) - _validate_block_call_channels_last(blocks._Residual, self) + thres = 1e-5 + diff_abs = tf.reshape(abs(dx - dx_true), [-1]) + assert all(diff_abs < thres) - def test_backward_grads_and_vars_channels_first(self): + +class _ResidualTest(tf.test.TestCase): + + def test_backward_grads_channels_first(self): """Test `backward_grads` function with `channels_first` data format.""" if not tf.test.is_gpu_available(): self.skipTest("GPU not available") @@ -256,6 +236,7 @@ class _ResidualTest(tf.test.TestCase): # Use double precision for testing x_true = tf.random_normal(shape=data_shape, dtype=tf.float64) dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) residual = blocks._Residual( filters=128, strides=(1, 1), @@ -264,16 +245,19 @@ class _ResidualTest(tf.test.TestCase): dtype=tf.float64) with tf.GradientTape() as tape: - x_true = tf.identity(x_true) tape.watch(x_true) - y = residual(x_true, training=True) + x1_true, x2_true = tf.split(x_true, num_or_size_splits=2, axis=1) + y1, y2 = residual((x1_true, x2_true), training=True) + y = tf.concat((y1, y2), axis=1) # Gradients computed due to reversibility - x, dx, dw, vars_ = residual.backward_grads_and_vars( - y, dy=dy, training=True) - + (x1, x2), (dx1, dx2), dw = residual.backward_grads( + y=(y1, y2), dy=(dy1, dy2), training=True) + x = tf.concat((x1, x2), axis=1) + dx = tf.concat((dx1, dx2), axis=1) # True gradients computed by the tape - grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy) + grads = tape.gradient( + y, [x_true] + residual.trainable_variables, output_gradients=dy) dx_true, dw_true = grads[0], grads[1:] self.assertAllClose(x_true, x) diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py index 821a4878c1cfe1aebe3697952059266def0f5817..29f1db0e0367515757413c8e47f7b7280fc4cfbb 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/config.py +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -82,7 +82,8 @@ def get_hparams_cifar_38(): config.num_train_images // config.tpu_batch_size) config.add_hparam("tpu_epochs", config.max_train_iter // config.tpu_iters_per_epoch) - + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) return config @@ -162,7 +163,8 @@ def get_hparams_imagenet_56(): config.num_train_images // config.tpu_batch_size) config.add_hparam("tpu_epochs", config.max_train_iter // config.tpu_iters_per_epoch) - + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) return config diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py new file mode 100644 index 0000000000000000000000000000000000000000..34a9984b0ecc527ad1991c28146246b716e96c98 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py @@ -0,0 +1,229 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Efficient ImageNet input pipeline using tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.revnet import resnet_preprocessing + + +def image_serving_input_fn(): + """Serving input fn for raw images.""" + + def _preprocess_image(image_bytes): + """Preprocess a single raw image.""" + image = resnet_preprocessing.preprocess_image( + image_bytes=image_bytes, is_training=False) + return image + + image_bytes_list = tf.placeholder( + shape=[None], + dtype=tf.string, + ) + images = tf.map_fn( + _preprocess_image, image_bytes_list, back_prop=False, dtype=tf.float32) + return tf.estimator.export.ServingInputReceiver( + images, {'image_bytes': image_bytes_list}) + + +class ImageNetInput(object): + """Generates ImageNet input_fn for training or evaluation. + + The training data is assumed to be in TFRecord format with keys as specified + in the dataset_parser below, sharded across 1024 files, named sequentially: + train-00000-of-01024 + train-00001-of-01024 + ... + train-01023-of-01024 + + The validation data is in the same format but sharded in 128 files. + + The format of the data required is created by the script at: + https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py + + Args: + is_training: `bool` for whether the input is for training + data_dir: `str` for the directory of the training and validation data; + if 'null' (the literal string 'null', not None), then construct a null + pipeline, consisting of empty images. + use_bfloat16: If True, use bfloat16 precision; else use float32. + transpose_input: 'bool' for whether to use the double transpose trick + num_cores: `int` for the number of TPU cores + """ + + def __init__(self, is_training, + use_bfloat16, + data_dir, + num_cores=8, + num_parallel_calls=64, + image_size=224, + transpose_input=False, + cache=False): + self.image_preprocessing_fn = resnet_preprocessing.preprocess_image + self.is_training = is_training + self.use_bfloat16 = use_bfloat16 + self.data_dir = data_dir + self.num_cores = num_cores + self.num_parallel_calls = num_parallel_calls + if self.data_dir == 'null' or self.data_dir == '': + self.data_dir = None + self.transpose_input = transpose_input + self.image_size = image_size + self.cache = cache + + def set_shapes(self, batch_size, images, labels): + """Statically set the batch_size dimension.""" + if self.transpose_input: + images.set_shape(images.get_shape().merge_with( + tf.TensorShape([None, None, None, batch_size]))) + labels.set_shape(labels.get_shape().merge_with( + tf.TensorShape([batch_size]))) + else: + images.set_shape(images.get_shape().merge_with( + tf.TensorShape([batch_size, None, None, None]))) + labels.set_shape(labels.get_shape().merge_with( + tf.TensorShape([batch_size]))) + + return images, labels + + def dataset_parser(self, value): + """Parse an ImageNet record from a serialized string Tensor.""" + keys_to_features = { + 'image/encoded': tf.FixedLenFeature((), tf.string, ''), + 'image/format': tf.FixedLenFeature((), tf.string, 'jpeg'), + 'image/class/label': tf.FixedLenFeature([], tf.int64, -1), + 'image/class/text': tf.FixedLenFeature([], tf.string, ''), + 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), + 'image/object/class/label': tf.VarLenFeature(dtype=tf.int64), + } + + parsed = tf.parse_single_example(value, keys_to_features) + image_bytes = tf.reshape(parsed['image/encoded'], shape=[]) + + image = self.image_preprocessing_fn( + image_bytes=image_bytes, + is_training=self.is_training, + image_size=self.image_size, + use_bfloat16=self.use_bfloat16) + + # Subtract one so that labels are in [0, 1000). + label = tf.cast( + tf.reshape(parsed['image/class/label'], shape=[]), dtype=tf.int32) - 1 + + return image, label + + def input_fn(self, params): + """Input function which provides a single batch for train or eval. + + Args: + params: `dict` of parameters passed from the `TPUEstimator`. + `params['batch_size']` is always provided and should be used as the + effective batch size. + + Returns: + A `tf.data.Dataset` object. + """ + if self.data_dir is None: + tf.logging.info('Using fake input.') + return self.input_fn_null(params) + + # Retrieves the batch size for the current shard. The # of shards is + # computed according to the input pipeline deployment. See + # tf.contrib.tpu.RunConfig for details. + batch_size = params['batch_size'] + + # Shuffle the filenames to ensure better randomization. + file_pattern = os.path.join( + self.data_dir, 'train-*' if self.is_training else 'validation-*') + dataset = tf.data.Dataset.list_files(file_pattern, shuffle=self.is_training) + + if self.is_training and not self.cache: + dataset = dataset.repeat() + + def fetch_dataset(filename): + buffer_size = 8 * 1024 * 1024 # 8 MiB per file + dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) + return dataset + + # Read the data from disk in parallel + dataset = dataset.apply( + tf.contrib.data.parallel_interleave( + fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True)) + if self.cache: + dataset = dataset.cache().apply( + tf.contrib.data.shuffle_and_repeat(1024 * 16)) + else: + dataset = dataset.shuffle(1024) + + # Use the fused map-and-batch operation. + # + # For XLA, we must used fixed shapes. Because we repeat the source training + # dataset indefinitely, we can use `drop_remainder=True` to get fixed-size + # batches without dropping any training examples. + # + # When evaluating, `drop_remainder=True` prevents accidentally evaluating + # the same image twice by dropping the final batch if it is less than a full + # batch size. As long as this validation is done with consistent batch size, + # exactly the same images will be used. + dataset = dataset.apply( + tf.contrib.data.map_and_batch( + self.dataset_parser, batch_size=batch_size, + num_parallel_batches=self.num_cores, drop_remainder=True)) + + # Transpose for performance on TPU + if self.transpose_input: + dataset = dataset.map( + lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), + num_parallel_calls=self.num_cores) + + # Assign static batch size dimension + dataset = dataset.map(functools.partial(self.set_shapes, batch_size)) + + # Prefetch overlaps in-feed with training + dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) + return dataset + + def input_fn_null(self, params): + """Input function which provides null (black) images.""" + batch_size = params['batch_size'] + dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input) + dataset = dataset.prefetch(batch_size) + + dataset = dataset.batch(batch_size, drop_remainder=True) + if self.transpose_input: + dataset = dataset.map( + lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), + num_parallel_calls=8) + + dataset = dataset.map(functools.partial(self.set_shapes, batch_size)) + + dataset = dataset.prefetch(32) # Prefetch overlaps in-feed with training + tf.logging.info('Input dataset: %s', str(dataset)) + return dataset + + def _get_null_input(self, _): + null_image = tf.zeros([224, 224, 3], tf.bfloat16 + if self.use_bfloat16 else tf.float32) + return (null_image, tf.constant(0, tf.int32)) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index dcd4e1697faae2a06b1b1581d6f6f0cfebeacde1..b702e91f92220c2a9003a1b82411131332012a9e 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -29,6 +29,11 @@ from tensorflow.contrib.eager.python.examples.revnet import revnet tfe = tf.contrib.eager +def apply_gradients(optimizer, grads, vars_, global_step=None): + """Functional style apply_grads for `tfe.defun`.""" + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + def main(_): """Eager execution workflow with RevNet trained on CIFAR-10.""" tf.enable_eager_execution() @@ -48,6 +53,11 @@ def main(_): if FLAGS.use_defun: model.call = tfe.defun(model.call) + model.compute_gradients = tfe.defun(model.compute_gradients) + model.get_moving_stats = tfe.defun(model.get_moving_stats) + model.restore_moving_stats = tfe.defun(model.restore_moving_stats) + global apply_gradients # pylint:disable=global-variable-undefined + apply_gradients = tfe.defun(apply_gradients) if FLAGS.train_dir: summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) @@ -197,9 +207,13 @@ def get_datasets(data_dir, config): def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + logits, saved_hiddens = model(inputs, training=True) + values = model.get_moving_stats() + grads, loss = model.compute_gradients(saved_hiddens, labels) + # Restore moving averages when executing eagerly to avoid updating twice + model.restore_moving_stats(values) + apply_gradients( + optimizer, grads, model.trainable_variables, global_step=global_step) return logits, loss diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py index 4868f1931f8cd9046e6e233c82f95969e355b6c2..3a17eb30da3b989acb0b33f2fcb730da76546c18 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py @@ -53,10 +53,11 @@ def model_fn(features, labels, mode, params): global_step, config.lr_decay_steps, config.lr_list) optimizer = tf.train.MomentumOptimizer( learning_rate, momentum=config.momentum) - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, vars_), global_step=global_step) + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) else: @@ -130,8 +131,7 @@ def get_input_fn(config, data_dir, split): return input_fn -def main(argv): - FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name +def main(_): tf.logging.set_verbosity(tf.logging.INFO) # RevNet specific configuration @@ -139,7 +139,7 @@ def main(argv): # Estimator specific configuration run_config = tf.estimator.RunConfig( - model_dir=FLAGS.train_dir, # Directory for storing checkpoints + model_dir=FLAGS.model_dir, # Directory for storing checkpoints tf_random_seed=config.seed, save_summary_steps=config.log_every, save_checkpoints_steps=config.log_every, @@ -153,7 +153,7 @@ def main(argv): # Construct estimator revnet_estimator = tf.estimator.Estimator( model_fn=model_fn, - model_dir=FLAGS.train_dir, + model_dir=FLAGS.model_dir, config=run_config, params={"config": config}) @@ -173,14 +173,14 @@ def main(argv): input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ "image": inputs }) - revnet_estimator.export_savedmodel(FLAGS.train_dir, input_fn) + revnet_estimator.export_savedmodel(FLAGS.model_dir, input_fn) if __name__ == "__main__": flags.DEFINE_string( "data_dir", default=None, help="Directory to load tfrecords") flags.DEFINE_string( - "train_dir", + "model_dir", default=None, help="[Optional] Directory to store the training information") flags.DEFINE_string( @@ -197,4 +197,4 @@ if __name__ == "__main__": help="[Optional] Architecture of network. " "Other options include `revnet-110` and `revnet-164`") FLAGS = flags.FLAGS - tf.app.run(main=main, argv=[FLAGS]) + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py index d809bcd287ccf26ef2d817168367f37c933b7182..8520cf5b71af503be35d5415707a283fb363a476 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py @@ -12,22 +12,90 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Cloud TPU Estimator workflow with RevNet train on CIFAR-10.""" +"""Cloud TPU Estimator workflow with RevNet train on ImageNet.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import time from absl import flags import tensorflow as tf -from tensorflow.contrib.eager.python.examples.revnet import cifar_input -from tensorflow.contrib.eager.python.examples.revnet import main as main_ +from tensorflow.contrib import summary +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import imagenet_input from tensorflow.contrib.eager.python.examples.revnet import revnet from tensorflow.contrib.training.python.training import evaluation -from tensorflow.python.estimator import estimator as estimator_ +from tensorflow.python.estimator import estimator + +MEAN_RGB = [0.485, 0.456, 0.406] +STDDEV_RGB = [0.229, 0.224, 0.225] + + +def _host_call_fn(gs, loss, lr): + """Training host call. + + Creates scalar summaries for training metrics. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the + model to the `metric_fn`, provide as part of the `host_call`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `host_call`. + + Args: + gs: `Tensor with shape `[batch]` for the global_step + loss: `Tensor` with shape `[batch]` for the training loss. + lr: `Tensor` with shape `[batch]` for the learning_rate. + + Returns: + List of summary ops to run on the CPU host. + """ + # Host call fns are executed FLAGS.iterations_per_loop times after one + # TPU loop is finished, setting max_queue value to the same as number of + # iterations will make the summary writer only flush the data to storage + # once per loop. + gs = gs[0] + with summary.create_file_writer( + FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default(): + with summary.always_record_summaries(): + summary.scalar("loss", loss[0], step=gs) + summary.scalar("learning_rate", lr[0], step=gs) + return summary.all_summary_ops() + + +def _metric_fn(labels, logits): + """Evaluation metric function. Evaluates accuracy. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the model + to the `metric_fn`, provide as part of the `eval_metrics`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `eval_metrics`. + + Args: + labels: `Tensor` with shape `[batch]`. + logits: `Tensor` with shape `[batch, num_classes]`. + + Returns: + A dict of the metrics to return from evaluation. + """ + predictions = tf.argmax(logits, axis=1) + top_1_accuracy = tf.metrics.accuracy(labels, predictions) + in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) + top_5_accuracy = tf.metrics.mean(in_top_5) + + return { + "top_1_accuracy": top_1_accuracy, + "top_5_accuracy": top_5_accuracy, + } def model_fn(features, labels, mode, params): @@ -42,50 +110,58 @@ def model_fn(features, labels, mode, params): Returns: An instance of `tf.contrib.tpu.TPUEstimatorSpec` """ + revnet_config = params["revnet_config"] + model = revnet.RevNet(config=revnet_config) inputs = features if isinstance(inputs, dict): inputs = features["image"] - FLAGS = params["FLAGS"] # pylint:disable=invalid-name,redefined-outer-name - config = params["config"] - model = revnet.RevNet(config=config) + if revnet_config.data_format == "channels_first": + assert not FLAGS.transpose_input # channels_first only for GPU + inputs = tf.transpose(inputs, [0, 3, 1, 2]) + + if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT: + inputs = tf.transpose(inputs, [3, 0, 1, 2]) # HWCN to NHWC + + # Normalize the image to zero mean and unit variance. + inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype) + inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype) if mode == tf.estimator.ModeKeys.TRAIN: global_step = tf.train.get_or_create_global_step() learning_rate = tf.train.piecewise_constant( - global_step, config.lr_decay_steps, config.lr_list) - optimizer = tf.train.MomentumOptimizer( - learning_rate, momentum=config.momentum) - + global_step, revnet_config.lr_decay_steps, revnet_config.lr_list) + optimizer = tf.train.MomentumOptimizer(learning_rate, + revnet_config.momentum) if FLAGS.use_tpu: optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) - # Define gradients - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, vars_), global_step=global_step) - - names = [v.name for v in model.variables] - tf.logging.warn("{}".format(names)) + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + if not FLAGS.skip_host_call: + # To log the loss, current learning rate, and epoch for Tensorboard, the + # summary op needs to be run on the host CPU via host_call. host_call + # expects [batch_size, ...] Tensors, thus reshape to introduce a batch + # dimension. These Tensors are implicitly concatenated to + # [params['batch_size']]. + gs_t = tf.reshape(global_step, [1]) + loss_t = tf.reshape(loss, [1]) + lr_t = tf.reshape(learning_rate, [1]) + host_call = (_host_call_fn, [gs_t, loss_t, lr_t]) return tf.contrib.tpu.TPUEstimatorSpec( - mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op) + mode=mode, loss=loss, train_op=train_op, host_call=host_call) elif mode == tf.estimator.ModeKeys.EVAL: logits, _ = model(inputs, training=False) loss = model.compute_loss(labels=labels, logits=logits) - def metric_fn(labels, logits): - predictions = tf.argmax(logits, axis=1) - accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions) - return { - "accuracy": accuracy, - } - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits])) + mode=mode, loss=loss, eval_metrics=(_metric_fn, [labels, logits])) else: # Predict or export logits, _ = model(inputs, training=False) @@ -102,117 +178,75 @@ def model_fn(features, labels, mode, params): }) -def get_input_fn(config, data_dir, split): - """Get the input function required by the `tf.contrib.tpu.TPUEstimator` API. - - Args: - config: Customized hyperparameters - data_dir: Directory where the data is stored - split: One of `train`, `validation`, `train_all`, and `test` - - Returns: - Input function required by the `tf.contrib.tpu.TPUEstimator` API - """ - - data_dir = os.path.join(data_dir, config.dataset) - # Fix split-dependent hyperparameters - if split == "train_all" or split == "train": - data_aug = True - epochs = config.tpu_epochs - shuffle = True - else: - data_aug = False - epochs = 1 - shuffle = False - - def input_fn(params): - """Input function required by the `tf.contrib.tpu.TPUEstimator` API.""" - batch_size = params["batch_size"] - return cifar_input.get_ds_from_tfrecords( - data_dir=data_dir, - split=split, - data_aug=data_aug, - batch_size=batch_size, # per-shard batch size - epochs=epochs, - shuffle=shuffle, - prefetch=batch_size, # per-shard batch size - data_format=config.data_format) - - return input_fn - - -def main(argv): - FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name +def main(_): tf.logging.set_verbosity(tf.logging.INFO) # RevNet specific configuration - config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + revnet_config = { + "revnet-56": config_.get_hparams_imagenet_56(), + "revnet-104": config_.get_hparams_imagenet_104() + }[FLAGS.revnet_config] if FLAGS.use_tpu: - tf.logging.info("Using TPU.") - tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( - FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - else: - tpu_cluster_resolver = None - - # TPU specific configuration - tpu_config = tf.contrib.tpu.TPUConfig( - # Recommended to be set as number of global steps for next checkpoint - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_shards) + revnet_config.data_format = "channels_last" + + tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) # Estimator specific configuration - run_config = tf.contrib.tpu.RunConfig( + config = tf.contrib.tpu.RunConfig( cluster=tpu_cluster_resolver, model_dir=FLAGS.model_dir, session_config=tf.ConfigProto( - allow_soft_placement=True, log_device_placement=False), - tpu_config=tpu_config, + allow_soft_placement=True, log_device_placement=True), + tpu_config=tf.contrib.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_shards, + per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig. + PER_HOST_V2), ) - # Construct TPU Estimator - estimator = tf.contrib.tpu.TPUEstimator( + # Input pipelines are slightly different (with regards to shuffling and + # preprocessing) between training and evaluation. + imagenet_train, imagenet_eval = [ + imagenet_input.ImageNetInput( + is_training=is_training, + data_dir=FLAGS.data_dir, + transpose_input=FLAGS.transpose_input, + use_bfloat16=False) for is_training in [True, False] + ] + + revnet_classifier = tf.contrib.tpu.TPUEstimator( model_fn=model_fn, use_tpu=FLAGS.use_tpu, - train_batch_size=config.tpu_batch_size, - eval_batch_size=config.tpu_eval_batch_size, - config=run_config, - params={ - "FLAGS": FLAGS, - "config": config, - }) - - # Construct input functions - train_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="train_all") - eval_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="test") - - # Disabling a range within an else block currently doesn't work - # due to https://github.com/PyCQA/pylint/issues/872 + train_batch_size=revnet_config.tpu_batch_size, + eval_batch_size=revnet_config.tpu_eval_batch_size, + config=config, + export_to_tpu=False, + params={"revnet_config": revnet_config}) + + steps_per_epoch = revnet_config.tpu_iters_per_epoch + eval_steps = revnet_config.tpu_eval_steps + # pylint: disable=protected-access if FLAGS.mode == "eval": - # TPUEstimator.evaluate *requires* a steps argument. - # Note that the number of examples used during evaluation is - # --eval_steps * --batch_size. - # So if you change --batch_size then change --eval_steps too. - eval_steps = 10000 // config.tpu_eval_batch_size - # Run evaluation when there's a new checkpoint for ckpt in evaluation.checkpoints_iterator( FLAGS.model_dir, timeout=FLAGS.eval_timeout): tf.logging.info("Starting to evaluate.") try: start_timestamp = time.time() # This time will include compilation time - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt) + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, + steps=eval_steps, + checkpoint_path=ckpt) elapsed_time = int(time.time() - start_timestamp) tf.logging.info("Eval results: %s. Elapsed seconds: %d" % (eval_results, elapsed_time)) # Terminate eval job when final checkpoint is reached current_step = int(os.path.basename(ckpt).split("-")[1]) - if current_step >= config.max_train_iter: + if current_step >= revnet_config.max_train_iter: tf.logging.info( "Evaluation finished after training step %d" % current_step) break @@ -226,37 +260,56 @@ def main(argv): "Checkpoint %s no longer exists, skipping checkpoint" % ckpt) else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval' - current_step = estimator_._load_global_step_from_checkpoint_dir( + current_step = estimator._load_global_step_from_checkpoint_dir( FLAGS.model_dir) - tf.logging.info("Training for %d steps . Current" - " step %d." % (config.max_train_iter, current_step)) + + tf.logging.info( + "Training for %d steps (%.2f epochs in total). Current" + " step %d." % (revnet_config.max_train_iter, + revnet_config.max_train_iter / steps_per_epoch, + current_step)) start_timestamp = time.time() # This time will include compilation time + if FLAGS.mode == "train": - estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, + max_steps=revnet_config.max_train_iter) + else: - eval_steps = 10000 // config.tpu_eval_batch_size assert FLAGS.mode == "train_and_eval" - while current_step < config.max_train_iter: + while current_step < revnet_config.max_train_iter: # Train for up to steps_per_eval number of steps. # At the end of training, a checkpoint will be written to --model_dir. next_checkpoint = min(current_step + FLAGS.steps_per_eval, - config.max_train_iter) - estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint) + revnet_config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, max_steps=next_checkpoint) current_step = next_checkpoint + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (next_checkpoint, int(time.time() - start_timestamp))) + # Evaluate the model on the most recent model in --model_dir. # Since evaluation happens in batches of --eval_batch_size, some images - # may be consistently excluded modulo the batch size. + # may be excluded modulo the batch size. As long as the batch size is + # consistent, the evaluated images are also consistent. tf.logging.info("Starting to evaluate.") - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps) + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, steps=eval_steps) tf.logging.info("Eval results: %s" % eval_results) - elapsed_time = int(time.time() - start_timestamp) - tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % - (config.max_train_iter, elapsed_time)) - # pylint: enable=protected-access + elapsed_time = int(time.time() - start_timestamp) + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (revnet_config.max_train_iter, elapsed_time)) + + if FLAGS.export_dir is not None: + # The guide to serve an exported TensorFlow model is at: + # https://www.tensorflow.org/serving/serving_basic + tf.logging.info("Starting to export model.") + revnet_classifier.export_savedmodel( + export_dir_base=FLAGS.export_dir, + serving_input_receiver_fn=imagenet_input.image_serving_input_fn) if __name__ == "__main__": @@ -288,14 +341,10 @@ if __name__ == "__main__": default=None, help="[Optional] Directory to store the model information") flags.DEFINE_string( - "dataset", - default="cifar-10", - help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") - flags.DEFINE_string( - "config", - default="revnet-38", + "revnet_config", + default="revnet-56", help="[Optional] Architecture of network. " - "Other options include `revnet-110` and `revnet-164`") + "Other options include `revnet-104`") flags.DEFINE_boolean( "use_tpu", default=True, help="[Optional] Whether to use TPU") flags.DEFINE_integer( @@ -309,20 +358,37 @@ if __name__ == "__main__": " train steps, the loop will exit before reaching" " --iterations_per_loop. The larger this value is, the higher the" " utilization on the TPU.")) - flags.DEFINE_string( - "mode", - default="train_and_eval", - help="[Optional] Mode to run: train, eval, train_and_eval") flags.DEFINE_integer( - "eval_timeout", 60 * 60 * 24, - "Maximum seconds between checkpoints before evaluation terminates.") + "eval_timeout", + default=None, + help="Maximum seconds between checkpoints before evaluation terminates.") flags.DEFINE_integer( "steps_per_eval", - default=1000, + default=5000, help=( "Controls how often evaluation is performed. Since evaluation is" " fairly expensive, it is advised to evaluate as infrequently as" " possible (i.e. up to --train_steps, which evaluates the model only" " after finishing the entire training regime).")) + flags.DEFINE_bool( + "transpose_input", + default=True, + help="Use TPU double transpose optimization") + flags.DEFINE_string( + "export_dir", + default=None, + help=("The directory where the exported SavedModel will be stored.")) + flags.DEFINE_bool( + "skip_host_call", + default=False, + help=("Skip the host_call which is executed every training step. This is" + " generally used for generating training summaries (train loss," + " learning rate, etc...). When --skip_host_call=false, there could" + " be a performance drop if host_call function is slow and cannot" + " keep up with the TPU-side computation.")) + flags.DEFINE_string( + "mode", + default="train_and_eval", + help='One of {"train_and_eval", "train", "eval"}.') FLAGS = flags.FLAGS - tf.app.run(main=main, argv=[FLAGS]) + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py b/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..21a1ab85d46cde11453e1f693cc4aabbbf3c90ed --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ImageNet preprocessing for ResNet.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +IMAGE_SIZE = 224 +CROP_PADDING = 32 + + +def distorted_bounding_box_crop(image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100, + scope=None): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image_bytes: `Tensor` of binary image data. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` + where each coordinate is [0, 1) and the coordinates are arranged + as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole + image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped + area of the image must contain at least this fraction of any bounding + box supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image + must contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + scope: Optional `str` for name scope. + Returns: + cropped image `Tensor` + """ + with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]): + shape = tf.image.extract_jpeg_shape(image_bytes) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + + return image + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _decode_and_random_crop(image_bytes, image_size): + """Make a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = distorted_bounding_box_crop( + image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. / 3.), + area_range=(0.08, 1.0), + max_attempts=10, + scope=None) + original_shape = tf.image.extract_jpeg_shape(image_bytes) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond( + bad, + lambda: _decode_and_center_crop(image_bytes, image_size), + lambda: tf.image.resize_bicubic([image], # pylint: disable=g-long-lambda + [image_size, image_size])[0]) + + return image + + +def _decode_and_center_crop(image_bytes, image_size): + """Crops to center of image with padding then scales image_size.""" + shape = tf.image.extract_jpeg_shape(image_bytes) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), + tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([offset_height, offset_width, + padded_center_crop_size, padded_center_crop_size]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = tf.image.resize_bicubic([image], [image_size, image_size])[0] + + return image + + +def _flip(image): + """Random horizontal image flip.""" + image = tf.image.random_flip_left_right(image) + return image + + +def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_random_crop(image_bytes, image_size) + image = _flip(image) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_center_crop(image_bytes, image_size) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_image(image_bytes, + is_training=False, + use_bfloat16=False, + image_size=IMAGE_SIZE): + """Preprocesses the given image. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + is_training: `bool` for whether the preprocessing is for training. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + if is_training: + return preprocess_for_train(image_bytes, use_bfloat16, image_size) + else: + return preprocess_for_eval(image_bytes, use_bfloat16, image_size) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index b1cb312b7459eb1d8926e6e6635ed2cfbed79833..1f2cb14972f0b92d29489adff8f94e790e1ec4ed 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -24,7 +24,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import blocks @@ -45,6 +44,7 @@ class RevNet(tf.keras.Model): self._init_block = blocks.InitBlock(config=self.config) self._final_block = blocks.FinalBlock(config=self.config) self._block_list = self._construct_intermediate_blocks() + self._moving_average_variables = [] def _construct_intermediate_blocks(self): # Precompute input shape after initial block @@ -128,126 +128,90 @@ class RevNet(tf.keras.Model): return tf.reduce_mean(cross_ent) - def compute_gradients(self, inputs, labels, training=True, l2_reg=True): + def compute_gradients(self, saved_hidden, labels, training=True, l2_reg=True): """Manually computes gradients. - When eager execution is enabled, this method also SILENTLY updates the - running averages of batch normalization when `training` is set to True. + This method silently updates the running averages of batch normalization. Args: - inputs: Image tensor, either NHWC or NCHW, conforming to `data_format` + saved_hidden: List of hidden states Tensors labels: One-hot labels for classification training: Use the mini-batch stats in batch norm if set to True l2_reg: Apply l2 regularization Returns: - A tuple with the first entry being a list of all gradients, the second - entry being a list of respective variables, the third being the logits, - and the forth being the loss + A tuple with the first entry being a list of all gradients and the second + being the loss """ - # Run forward pass to record hidden states - vars_and_vals = self.get_moving_stats() - _, saved_hidden = self(inputs, training=training) # pylint:disable=not-callable - if tf.executing_eagerly(): - # Restore moving averages when executing eagerly to avoid updating twice - self.restore_moving_stats(vars_and_vals) - else: - # Fetch batch norm updates in graph mode - updates = self.get_updates_for(inputs) - - grads_all = [] - vars_all = [] + def _defunable_pop(l): + """Functional style list pop that works with `tfe.defun`.""" + t, l = l[-1], l[:-1] + return t, l - # Manually backprop through last block + # Backprop through last block x = saved_hidden[-1] with tf.GradientTape() as tape: tape.watch(x) - # Running stats updated here logits = self._final_block(x, training=training) loss = self.compute_loss(logits, labels) - grads_combined = tape.gradient(loss, [x] + self._final_block.trainable_variables) - dy, grads_ = grads_combined[0], grads_combined[1:] - grads_all += grads_ - vars_all += self._final_block.trainable_variables + dy, final_grads = grads_combined[0], grads_combined[1:] - # Manually backprop through intermediate blocks + # Backprop through intermediate blocks + intermediate_grads = [] for block in reversed(self._block_list): - y = saved_hidden.pop() + y, saved_hidden = _defunable_pop(saved_hidden) x = saved_hidden[-1] - # Running stats updated here - dy, grads, vars_ = block.backward_grads_and_vars( - x, y, dy, training=training) - grads_all += grads - vars_all += vars_ - - # Manually backprop through first block - saved_hidden.pop() - x = saved_hidden.pop() - assert not saved_hidden # Cleared after backprop + dy, grads = block.backward_grads(x, y, dy, training=training) + intermediate_grads = grads + intermediate_grads + # Backprop through first block + _, saved_hidden = _defunable_pop(saved_hidden) + x, saved_hidden = _defunable_pop(saved_hidden) + assert not saved_hidden with tf.GradientTape() as tape: - # Running stats updated here y = self._init_block(x, training=training) - - grads_all += tape.gradient( + init_grads = tape.gradient( y, self._init_block.trainable_variables, output_gradients=dy) - vars_all += self._init_block.trainable_variables - # Apply weight decay + # Ordering match up with `model.trainable_variables` + grads_all = init_grads + final_grads + intermediate_grads if l2_reg: - grads_all = self._apply_weight_decay(grads_all, vars_all) - - if not tf.executing_eagerly(): - # Force updates to be executed before gradient computation in graph mode - # This does nothing when the function is wrapped in defun - with tf.control_dependencies(updates): - grads_all[0] = tf.identity(grads_all[0]) + grads_all = self._apply_weight_decay(grads_all) - return grads_all, vars_all, logits, loss + return grads_all, loss - def _apply_weight_decay(self, grads, vars_): + def _apply_weight_decay(self, grads): """Update gradients to reflect weight decay.""" - # Don't decay bias return [ g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g - for g, v in zip(grads, vars_) + for g, v in zip(grads, self.trainable_variables) ] def get_moving_stats(self): - """Get moving averages of batch normalization. - - This is needed to avoid updating the running average twice in one iteration. - - Returns: - A dictionary mapping variables for batch normalization moving averages - to their current values. - """ - vars_and_vals = {} - - def _is_moving_var(v): - n = v.name - return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + """Get moving averages of batch normalization.""" + device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" + with tf.device(device): + return [v.read_value() for v in self.moving_average_variables] + def restore_moving_stats(self, values): + """Restore moving averages of batch normalization.""" device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" with tf.device(device): - for v in filter(_is_moving_var, self.variables): - vars_and_vals[v] = v.read_value() + for var_, val in zip(self.moving_average_variables, values): + var_.assign(val) - return vars_and_vals + @property + def moving_average_variables(self): + """Get all variables that are batch norm moving averages.""" - def restore_moving_stats(self, vars_and_vals): - """Restore moving averages of batch normalization. + def _is_moving_avg(v): + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") - This is needed to avoid updating the running average twice in one iteration. + if not self._moving_average_variables: + self._moving_average_variables = filter(_is_moving_avg, self.variables) - Args: - vars_and_vals: The dictionary mapping variables to their previous values. - """ - device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" - with tf.device(device): - for var_, val in six.iteritems(vars_and_vals): - # `assign` causes a copy to GPU (if variable is already on GPU) - var_.assign(val) + return self._moving_average_variables diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 26b084752330ec6ccc3bbf34bcbfeb95a7429907..84b2ddf0de0739936d458ae1bce832cfbb167d64 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -31,9 +31,11 @@ tfe = tf.contrib.eager def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + logits, saved_hidden = model(inputs) + grads, loss = model.compute_gradients( + saved_hidden=saved_hidden, labels=labels) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) return logits, loss @@ -96,9 +98,10 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients(self): """Test `compute_gradients` function.""" - self.model(self.x, training=False) # Initialize model - grads, vars_, logits, loss = self.model.compute_gradients( - inputs=self.x, labels=self.t, training=True, l2_reg=True) + _, saved_hidden = self.model(self.x) # Initialize model + grads, loss = self.model.compute_gradients( + saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -107,7 +110,7 @@ class RevNetTest(tf.test.TestCase): # Compare against the true gradient computed by the tape with tf.GradientTape() as tape: - logits, _ = self.model(self.x, training=True) + logits, _ = self.model(self.x) loss_true = self.model.compute_loss(logits=logits, labels=self.t) grads_true = tape.gradient(loss_true, vars_) self.assertAllClose(loss, loss_true) @@ -122,7 +125,9 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" compute_gradients = tfe.defun(self.model.compute_gradients) - grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True) + _, saved_hidden = self.model(self.x) + grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -146,10 +151,11 @@ class RevNetTest(tf.test.TestCase): dtype=tf.int32) global_step = tf.Variable(0., trainable=False) model = revnet.RevNet(config=config) - grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True) + _, saved_hidden = model(x) + grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t) optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) train_op = optimizer.apply_gradients( - zip(grads_all, vars_all), global_step=global_step) + zip(grads, model.trainable_variables), global_step=global_step) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD deleted file mode 100644 index b470a41d815ce650731680065cc7341f844e3fdc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/BUILD +++ /dev/null @@ -1,59 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -# Model -py_library( - name = "config", - srcs = ["config.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "ops", - srcs = ["ops.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "sagan", - srcs = ["sagan.py"], - srcs_version = "PY2AND3", - deps = [ - ":ops", - "//tensorflow:tensorflow_py", - ], -) - -# Tests -cuda_py_test( - name = "ops_test", - size = "small", - srcs = ["ops_test.py"], - additional_deps = [ - ":ops", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "sagan_test", - size = "large", - srcs = ["sagan_test.py"], - additional_deps = [ - ":config", - ":sagan", - "//tensorflow:tensorflow_py", - ], - tags = [ - "optonly", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py deleted file mode 100644 index 1967bbd867447d9deaf9a7cb3b22a38889276a50..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/config.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Configuration in format of tf.contrib.training.HParams. -Supports default 128x128 ImageNet. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -tfe = tf.contrib.eager - - -def get_hparams_imagenet(): - """Configurations to train SAGAN on 128x128 ImageNet dataset.""" - config = tf.contrib.training.HParams() - if tf.test.is_gpu_available(): - config.add_hparam("image_shape", (3, 128, 128)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (512, 4, 4)) - else: - config.add_hparam("image_shape", (128, 128, 3)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (4, 4, 512)) - - config.add_hparam("latent_dim", 128) - config.add_hparam("update_g_once_every", 1) - config.add_hparam("batch_size", 64) - config.add_hparam("d_init_filters", 32) - config.add_hparam("num_upsamples", 5) - # (512, 4, 4) -> (3, 128, 128) - return config - - -def get_hparams_mock(): - """Configurations of smaller networks for testing.""" - config = tf.contrib.training.HParams() - if tf.test.is_gpu_available(): - config.add_hparam("image_shape", (3, 16, 16)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (32, 2, 2)) - else: - config.add_hparam("image_shape", (16, 16, 3)) - config.add_hparam("data_format", "channels_last") - config.add_hparam("g_init_shape", (2, 2, 32)) - - config.add_hparam("latent_dim", 16) - config.add_hparam("update_g_once_every", 1) - config.add_hparam("batch_size", 2) - config.add_hparam("d_init_filters", 4) - config.add_hparam("num_upsamples", 3) - # (32, 2, 2) -> (3, 16, 16) - return config diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py deleted file mode 100644 index 9a03cab1d12fc16baa7343f72ac58ccd39f698bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/ops.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Auxiliary operations. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - - -def flatten_hw(x, data_format="channels_first"): - """Flatten the input tensor across height and width dimensions.""" - if data_format == "channels_last": - x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first` - - old_shape = tf.shape(x) - new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]] - - return tf.reshape(x, new_shape) - - -def broaden_hw(x, h, w, c, data_format="channels_first"): - """Broaden dimension so that output has height and width.""" - if data_format == "channels_first": - shape = [-1, c, h, w] - else: - shape = [-1, h, w, c] - - return tf.reshape(x, shape) - - -class BroadenHW(tf.keras.layers.Layer): - """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`.""" - - def __init__(self, h, w, c, data_format="channels_first"): - super(BroadenHW, self).__init__() - self.h = h - self.w = w - self.c = c - self.data_format = data_format - - def call(self, x): - return broaden_hw( - x, h=self.h, w=self.w, c=self.c, data_format=self.data_format) - - def compute_output_shape(self, input_shape): - input_shape = tf.TensorShape(input_shape).as_list() - if self.data_format == "channels_first": - output_shape = (input_shape[0], self.c, self.h, self.w) - else: - output_shape = (input_shape[0], self.h, self.w, self.c) - - return tf.TensorShape(output_shape) diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py deleted file mode 100644 index 3454985904215b59d27fc4b76ccb4a8c2c2eff00..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for auxiliary operations.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import ops - - -class OpsTest(tf.test.TestCase): - - def test_flatten_hw(self): - """Test `flatten_hw` function with mock object.""" - - batch_size = 1 - # Default NCHW format - if tf.test.is_gpu_available(): - x = tf.random_normal(shape=(batch_size, 3, 4, 4)) - y = ops.flatten_hw(x, data_format="channels_first") - self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) - - # NHWC format - x = tf.random_normal(shape=(batch_size, 4, 4, 3)) - y = ops.flatten_hw(x, data_format="channels_last") - self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) - - def test_broaden_hw(self): - """Test `broaden_hw` function with mock object.""" - - batch_size = 1 - # NHWC format - x = tf.random_normal(shape=[batch_size, 4 * 4 * 16]) - y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last") - self.assertEqual(y.shape, (batch_size, 4, 4, 16)) - - # Default NCHW format - if tf.test.is_gpu_available(): - y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first") - self.assertEqual(y.shape, (batch_size, 16, 4, 4)) - - -if __name__ == "__main__": - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py deleted file mode 100644 index 81304149851675e07a3c7f9ad92697da2017022b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Code for main model. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import ops -tfe = tf.contrib.eager - - -class SelfAttentionModule(tf.keras.Model): - """Self-attention module composed of convolutional layers.""" - - def __init__(self, - attention_features, - original_features, - data_format="channels_first"): - """Initialize the module. - - Args: - attention_features: Number of filters for the attention computation. - original_features: Number of filters of the original Tensor. - data_format: Either 'channels_first' or 'channels_last' - """ - super(SelfAttentionModule, self).__init__() - self.data_format = data_format - # Matrix multiplication implemented as 2D Convolution - self.f = tf.keras.layers.Conv2D( - filters=attention_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.g = tf.keras.layers.Conv2D( - filters=attention_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.h = tf.keras.layers.Conv2D( - filters=original_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.scale = tf.Variable(0., trainable=True) - - def call(self, x): - f = self.f(x) - g = self.g(x) - h = self.h(x) - - f_flatten = ops.flatten_hw(f, data_format=self.data_format) - g_flatten = ops.flatten_hw(g, data_format=self.data_format) - h_flatten = ops.flatten_hw(h, data_format=self.data_format) - - s = tf.matmul(g_flatten, f_flatten, transpose_b=True) - b = tf.nn.softmax(s, axis=-1) - o = tf.matmul(b, h_flatten) - y = self.scale * tf.reshape(o, tf.shape(x)) + x - - return y - - def compute_output_shape(self, input_shape): - return input_shape - - -class SAGAN(tf.contrib.checkpoint.Checkpointable): - """Self-attention generative adversarial network.""" - - def __init__(self, config): - """Initialize the model. - - Args: - config: tf.contrib.training.HParams object; specifies hyperparameters - """ - super(SAGAN, self).__init__() - self.config = config - self.generator = self._construct_generator() - self.discriminator = self._construct_discriminator() - - def _construct_generator(self): - """Construct generator.""" - # TODO(lxuechen): Add spectral normalization for WGAN - axis = 1 if self.config.data_format == "channels_first" else 3 - - generator = tf.keras.Sequential() - generator.add( - tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,))) - generator.add( - tf.keras.layers.Dense( - units=np.prod(self.config.g_init_shape), activation=tf.nn.relu)) - - if self.config.data_format == "channels_first": - c, h, w = self.config.g_init_shape - else: - h, w, c = self.config.g_init_shape - - # Reshape to NHWC/NCHW - generator.add( - ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format)) - - filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)] - filters_list[-1] = 3 # Standard RGB images - - for filters in filters_list[:len(filters_list) // 2]: - generator.add( - tf.keras.layers.Conv2DTranspose( - filters=filters, - kernel_size=4, - strides=(2, 2), - use_bias=False, - padding="SAME", - data_format=self.config.data_format)) - generator.add(tf.keras.layers.BatchNormalization(axis=axis)) - generator.add(tf.keras.layers.Activation("relu")) - - # pylint: disable=undefined-loop-variable - generator.add( - SelfAttentionModule( - original_features=filters, - attention_features=filters // 8, - data_format=self.config.data_format)) - # pylint: enable=undefined-loop-variable - - for filters in filters_list[len(filters_list) // 2:]: - generator.add( - tf.keras.layers.Conv2DTranspose( - filters=filters, - kernel_size=4, - strides=(2, 2), - use_bias=False, - padding="SAME", - data_format=self.config.data_format)) - if filters == 3: - # Assume Image rescaled to [-1, 1] - generator.add(tf.keras.layers.Activation("tanh")) - else: - generator.add(tf.keras.layers.BatchNormalization(axis=axis)) - generator.add(tf.keras.layers.Activation("relu")) - - return generator - - def _construct_discriminator(self): - """Construct discriminator.""" - # TODO(lxuechen): Add spectral normalization for WGAN - discriminator = tf.keras.Sequential() - discriminator.add( - tf.keras.layers.InputLayer(input_shape=self.config.image_shape)) - - filters_list = [ - self.config.d_init_filters * 2**p - for p in range(self.config.num_upsamples) - ] - - for filters in filters_list[:(len(filters_list) + 1) // 2]: - discriminator.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=4, - strides=(2, 2), - padding="SAME", - data_format=self.config.data_format)) - discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) - - # pylint: disable=undefined-loop-variable - discriminator.add( - SelfAttentionModule( - original_features=filters, - attention_features=filters // 8, - data_format=self.config.data_format)) - # pylint: enable=undefined-loop-variable - - for filters in filters_list[(len(filters_list) + 1) // 2:]: - discriminator.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=4, - strides=(2, 2), - padding="SAME", - data_format=self.config.data_format)) - discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) - - discriminator.add(tf.keras.layers.Flatten()) - discriminator.add(tf.keras.layers.Dense(units=1)) - - return discriminator - - def compute_loss_and_grads(self, real_images, noise, training=True): - """Compute loss and gradients for both generator and discriminator.""" - # TODO(lxuechen): Add gradient penalty for discriminator - with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: - real_logits = self.discriminator(real_images, training=training) - - fake_images = self.generator.call(noise, training=training) - fake_logits = self.discriminator.call(fake_images) - - g_loss = self.compute_g_loss(fake_logits) - d_loss = self.compute_d_loss(fake_logits, real_logits) - - g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables) - d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables) - - return g_loss, d_loss, g_grads, d_grads - - def compute_g_loss(self, fake_logits): - return -tf.reduce_mean(fake_logits) # Hinge loss - - def compute_d_loss(self, fake_logits, real_logits): - # Hinge loss - real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits)) - fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits)) - return real_loss + fake_loss diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py deleted file mode 100644 index 18345945108111b57c5401c26b7dca0bfc8f8316..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for self-attention generative adversarial network.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import config as config_ -from tensorflow.contrib.eager.python.examples.sagan import sagan -tfe = tf.contrib.eager - - -class SAGANTest(tf.test.TestCase): - - def setUp(self): - super(SAGANTest, self).setUp() - config = config_.get_hparams_mock() - self.noise_shape = (config.batch_size, config.latent_dim) - self.logits_shape = (config.batch_size, 1) - self.images_shape = (config.batch_size,) + config.image_shape - - self.model = sagan.SAGAN(config=config) - self.noise = tf.random_normal(shape=self.noise_shape) - self.real_images = tf.random_normal(shape=self.images_shape) - self.config = config - - def tearDown(self): - del self.model - del self.noise - del self.real_images - super(SAGANTest, self).tearDown() - - def test_generator_call(self): - """Test `generator.__call__` function.""" - fake_images = self.model.generator(self.noise, training=False) - self.assertEqual(fake_images.shape, self.images_shape) - - def test_generator_call_defun(self): - """Test `generator.__call__` function with defun.""" - call_ = tfe.defun(self.model.generator.__call__) - fake_images = call_(self.noise, training=False) - self.assertEqual(fake_images.shape, self.images_shape) - - def test_discriminator_call(self): - """Test `discriminator.__call__` function.""" - real_logits = self.model.discriminator(self.real_images) - self.assertEqual(real_logits.shape, self.logits_shape) - - def test_discriminator_call_defun(self): - """Test `discriminator.__call__` function with defun.""" - call_ = tfe.defun(self.model.discriminator.__call__) - real_logits = call_(self.real_images) - self.assertEqual(real_logits.shape, self.logits_shape) - - def test_compute_loss_and_grads(self): - """Test `compute_loss_and_grads` function.""" - g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads( - self.real_images, self.noise, training=False) - self.assertEqual(g_loss.shape, ()) - self.assertEqual(d_loss.shape, ()) - self.assertTrue(isinstance(g_grads, list)) - self.assertTrue(isinstance(d_grads, list)) - g_vars = self.model.generator.trainable_variables - d_vars = self.model.discriminator.trainable_variables - - self.assertEqual(len(g_grads), len(g_vars)) - self.assertEqual(len(d_grads), len(d_vars)) - - def test_compute_loss_and_grads_defun(self): - """Test `compute_loss_and_grads` function with defun.""" - compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads) - g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads( - self.real_images, self.noise, training=False) - self.assertEqual(g_loss.shape, ()) - self.assertEqual(d_loss.shape, ()) - self.assertTrue(isinstance(g_grads, list)) - self.assertTrue(isinstance(d_grads, list)) - g_vars = self.model.generator.trainable_variables - d_vars = self.model.discriminator.trainable_variables - - self.assertEqual(len(g_grads), len(g_vars)) - self.assertEqual(len(d_grads), len(d_vars)) - - -if __name__ == "__main__": - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 8ac553e0ae71382966d03d9ef4429adf5137b369..d18a097063c7d25947af3e2e2959ce574edd553f 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) object_graph = checkpointable_utils.object_metadata( - saver.latest_checkpoint(config.logdir)) + checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index ca6430253b67d825290b6a376ba3f29b3ae67577..de11d00a1a0a34372467eedb02d790c920e7f449 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -34,6 +34,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@run @@enable_eager_execution +@@enable_remote_eager_execution @@custom_gradient @@ -70,6 +71,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@run_test_in_graph_and_eager_modes @@run_all_tests_in_graph_and_eager_modes +@@TensorSpec + @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT @@ -113,7 +116,9 @@ from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution +from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 1aa3df8d8d8fc4bfd1f4543f3198f2581632a83f..349f48f7f788b458af2639f7ad4cc4cd904465b4 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -28,6 +28,7 @@ py_library( ":multi_head", ":replicate_model_fn", ":rnn", + ":saved_model_estimator", "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -465,3 +466,43 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) + +py_library( + name = "saved_model_estimator", + srcs = ["python/estimator/saved_model_estimator.py"], + deps = [ + ":export", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model", + ], +) + +py_test( + name = "saved_model_estimator_test", + size = "medium", + srcs = ["python/estimator/saved_model_estimator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":export", + ":saved_model_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + ], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 09fcfd66a1f90d5a323b09472c4b7f5b1234ee35..e1453ae1d04ebd8d72f812b51480f0b05f7a5416 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -33,6 +33,8 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.contrib.estimator.python.estimator.rnn import * +from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import * +from tensorflow.python.estimator.export.export import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -70,6 +72,9 @@ _allowed_symbols = [ 'stop_if_higher_hook', 'stop_if_no_increase_hook', 'stop_if_no_decrease_hook', + 'build_raw_supervised_input_receiver_fn', + 'build_supervised_input_receiver_fn_from_input_fn', + 'SavedModelEstimator' ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index 43bfcffd790e7b3c716c3f70820851a8819af225..7ed77bcce6f00ed13e9952951800f1017d582f19 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -50,7 +50,8 @@ class _BoostedTreesEstimator(estimator.Estimator): tree_complexity=0., min_node_weight=0., config=None, - center_bias=False): + center_bias=False, + pruning_mode='none'): """Initializes a `BoostedTreesEstimator` instance. Args: @@ -89,13 +90,18 @@ class _BoostedTreesEstimator(estimator.Estimator): regression problems, the first node will return the mean of the labels. For binary classification problems, it will return a logit for a prior probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. """ # pylint:disable=protected-access # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -129,7 +135,8 @@ def boosted_trees_classifier_train_in_memory( min_node_weight=0., config=None, train_hooks=None, - center_bias=False): + center_bias=False, + pruning_mode='none'): """Trains a boosted tree classifier with in memory dataset. Example: @@ -208,6 +215,11 @@ def boosted_trees_classifier_train_in_memory( regression problems, the first node will return the mean of the labels. For binary classification problems, it will return a logit for a prior probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -228,7 +240,7 @@ def boosted_trees_classifier_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -269,7 +281,8 @@ def boosted_trees_regressor_train_in_memory( min_node_weight=0., config=None, train_hooks=None, - center_bias=False): + center_bias=False, + pruning_mode='none'): """Trains a boosted tree regressor with in memory dataset. Example: @@ -341,6 +354,11 @@ def boosted_trees_regressor_train_in_memory( regression problems, the first node will return the mean of the labels. For binary classification problems, it will return a logit for a prior probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -360,7 +378,7 @@ def boosted_trees_regressor_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index 999c2aa5e28242f996e12da3807a74c6acf31df9..b1581f37509b5dc2bec98942e88c024905f25d93 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -136,6 +136,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): eval_res = est.evaluate(input_fn=input_fn, steps=1) self.assertAllClose(eval_res['average_loss'], 0.614642) + def testTrainAndEvaluateEstimatorWithPrePruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='pre') + + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 2 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 3.83943) + + def testTrainAndEvaluateEstimatorWithPostPruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='post') + + # It will stop after 10 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.37652) + def testInferEstimator(self): train_input_fn = _make_train_input_fn(is_classification=False) predict_input_fn = numpy_io.numpy_input_fn( @@ -231,6 +274,31 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllClose([[0], [1], [1], [0], [0]], [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5, + pruning_mode='pre', + tree_complexity=0.01) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 1 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryWithDataset(self): train_input_fn = _make_train_input_fn_dataset(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..ce98e9987ec728fadf170e56fe4bfe24fc9a0105 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py @@ -0,0 +1,449 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class that creates an Estimator from a SavedModel.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export as export_lib +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training_util + + +class SavedModelEstimator(estimator_lib.Estimator): + """Create an Estimator from a SavedModel. + + Only SavedModels exported with + `tf.contrib.estimator.export_all_saved_models()` or + `tf.estimator.Estimator.export_savedmodel()` are supported for this class. + + Example with `tf.estimator.DNNClassifier`: + + **Step 1: Create and train DNNClassifier.** + + ```python + feature1 = tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_vocabulary_list( + key='feature1', vocabulary_list=('green', 'yellow')), dimension=1) + feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0) + + classifier = tf.estimator.DNNClassifier( + hidden_units=[4,2], feature_columns=[feature1, feature2]) + + def input_fn(): + features = {'feature1': tf.constant(['green', 'green', 'yellow']), + 'feature2': tf.constant([3.5, 4.2, 6.1])} + label = tf.constant([1., 0., 0.]) + return tf.data.Dataset.from_tensors((features, label)).repeat() + + classifier.train(input_fn=input_fn, steps=10) + ``` + + **Step 2: Export classifier.** + First, build functions that specify the expected inputs. + + ```python + # During train and evaluation, both the features and labels should be defined. + supervised_input_receiver_fn = ( + tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + {'feature1': tf.placeholder(dtype=tf.string, shape=[None]), + 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])}, + tf.placeholder(dtype=tf.float32, shape=[None]))) + + # During predict mode, expect to receive a `tf.Example` proto, so a parsing + # function is used. + serving_input_receiver_fn = ( + tf.estimator.export.build_parsing_serving_input_receiver_fn( + tf.feature_column.make_parse_example_spec([feature1, feature2]))) + ``` + + Next, export the model as a SavedModel. A timestamped directory will be + created (for example `/tmp/export_all/1234567890`). + + ```python + # Option 1: Save all modes (train, eval, predict) + export_dir = tf.contrib.estimator.export_all_saved_models( + classifier, '/tmp/export_all', + {tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn, + tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn, + tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn}) + + # Option 2: Only export predict mode + export_dir = classifier.export_savedmodel( + '/tmp/export_predict', serving_input_receiver_fn) + ``` + + **Step 3: Create a SavedModelEstimator from the exported SavedModel.** + + ```python + est = tf.contrib.estimator.SavedModelEstimator(export_dir) + + # If all modes were exported, you can immediately evaluate and predict, or + # continue training. Otherwise only predict is available. + eval_results = est.evaluate(input_fn=input_fn, steps=1) + print(eval_results) + + est.train(input_fn=input_fn, steps=20) + + def predict_input_fn(): + example = tf.train.Example() + example.features.feature['feature1'].bytes_list.value.extend(['yellow']) + example.features.feature['feature2'].float_list.value.extend([1.]) + return {'inputs':tf.constant([example.SerializeToString()])} + + predictions = est.predict(predict_input_fn) + print(next(predictions)) + ``` + """ + + def __init__(self, saved_model_dir, model_dir=None): + """Initialize a SavedModelEstimator. + + The SavedModelEstimator loads its model function and variable values from + the graphs defined in the SavedModel. There is no option to pass in + `RunConfig` or `params` arguments, because the model function graph is + defined statically in the SavedModel. + + Args: + saved_model_dir: Directory containing SavedModel protobuf and subfolders. + model_dir: Directory to save new checkpoints during training. + + Raises: + NotImplementedError: If a DistributionStrategy is defined in the config. + Unless the SavedModelEstimator is subclassed, this shouldn't happen. + """ + checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access + vars_to_warm_start = [name for name, _ in + checkpoint_utils.list_variables(checkpoint)] + warm_start_settings = estimator_lib.WarmStartSettings( + ckpt_to_initialize_from=checkpoint, + vars_to_warm_start=vars_to_warm_start) + + super(SavedModelEstimator, self).__init__( + model_fn=self._model_fn_from_saved_model, model_dir=model_dir, + warm_start_from=warm_start_settings) + if self._train_distribution or self._eval_distribution: + raise NotImplementedError( + 'SavedModelEstimator currently does not support ' + 'DistributionStrategy.') + self.saved_model_dir = saved_model_dir + self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir) + self._available_modes = self._extract_available_modes() + + def _extract_available_modes(self): + """Return list of modes found in SavedModel.""" + available_modes = [] + logging.info('Checking available modes for SavedModelEstimator.') + for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + try: + self._get_meta_graph_def_for_mode(mode) + except RuntimeError: + logging.warning('%s mode not found in SavedModel.' % mode) + continue + + if self._get_signature_def_for_mode(mode) is not None: + available_modes.append(mode) + + logging.info('Available modes for Estimator: %s' % available_modes) + return available_modes + + def _validate_mode(self, mode): + """Make sure that mode can be run using the SavedModel.""" + if mode not in self._available_modes: + raise RuntimeError('%s mode is not available in the SavedModel. Use ' + 'saved_model_cli to check that the Metagraph for this ' + 'mode has been exported.' % mode) + + def _get_meta_graph_def_for_mode(self, mode): + tags = model_fn_lib.EXPORT_TAG_MAP[mode] + return self.saved_model_loader.get_meta_graph_def_from_tags(tags) + + def _get_signature_def_for_mode(self, mode): + meta_graph_def = self._get_meta_graph_def_for_mode(mode) + sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + if mode == model_fn_lib.ModeKeys.PREDICT else mode) + if sig_def_key not in meta_graph_def.signature_def: + logging.warning('Metagraph for mode %s was found, but SignatureDef with' + ' key \"%s\" is missing.' % (mode, sig_def_key)) + return None + return meta_graph_def.signature_def[sig_def_key] + + def _create_and_assert_global_step(self, graph): + # Do nothing here. The global step variable will be created/loaded from the + # SavedModel. If a global step variable were created here, the result + # will be two duplicate global step variables, causing issues during + # the warm-start phase. + # Due to the global variable being created in the model function, this may + # cause issues when running DistributionStrategy. Thus, DistributionStrategy + # is not yet supported with SavedModelEstimator. + return None + + def _model_fn_from_saved_model(self, features, labels, mode): + """Load a SavedModel graph and return an EstimatorSpec.""" + # TODO(kathywu): Model function loads placeholders from the graph. Calling + # export_all_saved_models creates another placeholder for the inputs, on top + # of the original placeholders. There should be a way to avoid this. + self._validate_mode(mode) + + g = ops.get_default_graph() + if training_util.get_global_step(g) is not None: + raise RuntimeError( + 'Graph must not contain a global step tensor before the SavedModel is' + ' loaded. Please make sure that the input function does not create a ' + 'global step.') + + # Extract SignatureDef for information about the input and output tensors. + signature_def = self._get_signature_def_for_mode(mode) + + # Generate input map for replacing the inputs in the SavedModel graph with + # the provided features and labels. + input_map = _generate_input_map(signature_def, features, labels) + + # Create a list of the names of output tensors. When the graph is loaded, + # names of the output tensors may be remapped. This ensures that the correct + # tensors are returned in the EstimatorSpec. + output_tensor_names = [ + value.name for value in six.itervalues(signature_def.outputs)] + + # Load the graph. `output_tensors` contains output `Tensors` in the same + # same order as the `output_tensor_names` list. + tags = model_fn_lib.EXPORT_TAG_MAP[mode] + _, output_tensors = self.saved_model_loader.load_graph( + g, tags, input_map=input_map, return_elements=output_tensor_names) + + # Create a scaffold from the MetaGraphDef that contains ops to initialize + # the graph. This should mirror the steps from _add_meta_graph_for_mode(), + # which creates a MetaGraphDef from the EstimatorSpec's scaffold. + scaffold = monitored_session.Scaffold( + local_init_op=loader_impl._get_main_op_tensor( # pylint: disable=protected-access + self._get_meta_graph_def_for_mode(mode))) + + # Ensure that a global step tensor has been created. + global_step_tensor = training_util.get_global_step(g) + training_util.assert_global_step(global_step_tensor) + + # Extract values to return in the EstimatorSpec. + output_map = dict(zip(output_tensor_names, output_tensors)) + outputs = {key: output_map[value.name] + for key, value in six.iteritems(signature_def.outputs)} + + loss, predictions, metrics = _validate_and_extract_outputs( + mode, outputs, signature_def.method_name) + + train_op = ops.get_collection(constants.TRAIN_OP_KEY) + if len(train_op) > 1: + raise RuntimeError('Multiple ops found in the train_op collection.') + train_op = None if not train_op else train_op[0] + + _clear_saved_model_collections() + return model_fn_lib.EstimatorSpec( + scaffold=scaffold, + mode=mode, + loss=loss, + train_op=train_op, + predictions=predictions, + eval_metric_ops=metrics) + + +def _clear_saved_model_collections(): + """Clear collections that are expected empty when exporting a SavedModel. + + The SavedModel builder uses these collections to track ops necessary to + restore the graph state. These collections are expected to be empty before + MetaGraphs are added to the builder. + """ + del ops.get_collection_ref(constants.ASSETS_KEY)[:] + del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:] + del ops.get_collection_ref(constants.MAIN_OP_KEY)[:] + del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:] + + +def _generate_input_map(signature_def, features, labels): + """Return dict mapping an input tensor name to a feature or label tensor. + + Args: + signature_def: SignatureDef loaded from SavedModel + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. + labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the labels to be passed to the model. May be + `None`. + + Returns: + dict mapping string names of inputs to features or labels tensors + + Raises: + ValueError: if SignatureDef inputs are not completely mapped by the input + features and labels. + """ + # pylint: disable=protected-access + if not isinstance(features, dict): + features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features} + if labels is not None and not isinstance(labels, dict): + labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels} + # pylint: enable=protected-access + + inputs = signature_def.inputs + input_map = {} + for key, tensor_info in six.iteritems(inputs): + input_name = tensor_info.name + if ':' in input_name: + input_name = input_name[:input_name.find(':')] + + # When tensors are used as control inputs for operations, their names are + # prepended with a '^' character in the GraphDef. To handle possible control + # flow edge cases, control input names must be included in the input map. + control_dependency_name = '^' + input_name + + if key in features: + _check_same_dtype_and_shape(features[key], tensor_info, key) + input_map[input_name] = input_map[control_dependency_name] = features[key] + elif labels is not None and key in labels: + _check_same_dtype_and_shape(labels[key], tensor_info, key) + input_map[input_name] = input_map[control_dependency_name] = labels[key] + else: + raise ValueError( + 'Key \"%s\" not found in features or labels passed in to the model ' + 'function. All required keys: %s' % (key, inputs.keys())) + + return input_map + + +def _check_same_dtype_and_shape(tensor, tensor_info, name): + """Validate that tensor has the same properties as the TensorInfo proto. + + Args: + tensor: a `Tensor` object. + tensor_info: a `TensorInfo` proto. + name: Name of the input (to identify Tensor if an error is raised). + + Raises: + ValueError: If the tensor shape or dtype don't match the TensorInfo + """ + dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype)) + shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape) + + if dtype_error or shape_error: + msg = 'Tensor shape and/or dtype validation failed for input %s:' % name + if dtype_error: + msg += ('\n\tExpected dtype: %s, Got: %s' + % (dtypes.DType(tensor_info.dtype), tensor.dtype)) + if shape_error: + msg += ('\n\tExpected shape: %s, Got: %s' + % (tensor_shape.TensorShape(tensor_info.tensor_shape), + tensor.shape)) + + raise ValueError(msg) + + +def _extract_eval_metrics(output_dict): + """Return a eval metric dict extracted from the output_dict. + + Eval metrics consist of a value tensor and an update op. Both must be in the + passed-in tensor dictionary for an eval metric to be added to the returned + dictionary. + + Args: + output_dict: a dict that maps strings to tensors. + + Returns: + dict mapping strings to (value, update_op) tuples. + """ + # pylint: disable=protected-access + metric_ops = {} + separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR + + for key, tensor in six.iteritems(output_dict): + split_key = key.split(separator_char) + + # The metric name may contain the separator character, so recreate its name. + metric_name = separator_char.join(split_key[:-1]) + + if split_key[0] == export_output._SupervisedOutput.METRICS_NAME: + # If the key ends with the value suffix, and there is a corresponding + # key ending with the update_op suffix, then add tensors to metrics dict. + if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX: + update_op = ''.join( + [metric_name, separator_char, + export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX]) + if update_op in output_dict: + update_op_tensor = output_dict[update_op] + metric_ops[metric_name] = (tensor, update_op_tensor) + + # pylint: enable=protected-access + return metric_ops + + +def _validate_and_extract_outputs(mode, output_dict, method_name): + """Extract values from SignatureDef output dictionary. + + Args: + mode: One of the modes enumerated in `tf.estimator.ModeKeys`. + output_dict: dict of string SignatureDef keys to `Tensor`. + method_name: Method name of the SignatureDef as a string. + + Returns: + Tuple of ( + loss: `Tensor` object, + predictions: dictionary mapping string keys to `Tensor` objects, + metrics: dictionary mapping string keys to a tuple of two `Tensor` objects + ) + + Raises: + RuntimeError: raised if SignatureDef has an invalid method name for the mode + """ + # pylint: disable=protected-access + loss, predictions, metrics = None, None, None + + if mode == model_fn_lib.ModeKeys.PREDICT: + predictions = output_dict + else: + # Validate that the SignatureDef's method name matches the expected name for + # the given mode. + expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME + if mode == model_fn_lib.ModeKeys.EVAL: + expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME + if method_name != expected_method_name: + raise RuntimeError( + 'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t' + 'Got: %s\nPlease ensure that the SavedModel was exported with ' + '`tf.contrib.estimator.export_all_saved_models()`.' % + (mode, expected_method_name, method_name)) + + # Extract loss, metrics and predictions from the output dict. + loss = output_dict[export_output._SupervisedOutput.LOSS_NAME] + metrics = _extract_eval_metrics(output_dict) + predictions = { + key: value for key, value in six.iteritems(output_dict) + if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == ( + export_output._SupervisedOutput.PREDICTIONS_NAME)} + + # pylint: enable=protected-access + return loss, predictions, metrics diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..718da1367ce69285f37269c5631fa0be2b050c97 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py @@ -0,0 +1,369 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModelEstimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.contrib.estimator.python.estimator import saved_model_estimator +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training + + +def dummy_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)}, + constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat() + + +def dummy_input_fn_features_only(): + return dataset_ops.Dataset.from_tensors( + {'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat() + + +def dummy_supervised_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[2, 1], name='truth') + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +def dummy_serving_receiver_fn(): + feature_spec = {'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'),} + return export.build_raw_serving_input_receiver_fn(feature_spec) + + +def model_fn_diff_modes(features, labels, mode): + _, _ = features, labels + v = variables.Variable(21, name='some_var') + train_op = None + loss = constant_op.constant(104) + if mode == model_fn_lib.ModeKeys.TRAIN: + loss = constant_op.constant(105) + predictions = constant_op.constant([501]) + train_op = control_flow_ops.group( + state_ops.assign_add(training.get_global_step(), 1), + state_ops.assign_add(v, 3)) + elif mode == model_fn_lib.ModeKeys.EVAL: + loss = constant_op.constant(106) + predictions = constant_op.constant([502]) + else: + loss = constant_op.constant(107) + predictions = constant_op.constant([503]) + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=train_op, + eval_metric_ops={ + 'abs_err': metrics_lib.mean_absolute_error( + constant_op.constant(0), predictions)}, + predictions=predictions) + + +class SavedModelEstimatorTest(test.TestCase): + + def setUp(self): + self.tmpdirs = [] + + def tearDown(self): + for tmpdir in self.tmpdirs: + # gfile.DeleteRecursively fails in the windows cmake test, so use shutil. + shutil.rmtree(tmpdir, ignore_errors=True) + self.tmpdirs = [] + + def _get_tmp_dir(self): + tmpdir = tempfile.mkdtemp() + self.tmpdirs.append(tmpdir) + return tmpdir + + def _export_estimator(self, train=True, evaluate=True, predict=True, + model_fn=model_fn_diff_modes): + est = estimator.Estimator(model_fn, self._get_tmp_dir()) + est.train(input_fn=dummy_input_fn, steps=10) + + input_receiver_fn_map = {} + if train: + input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = ( + dummy_supervised_receiver_fn()) + if evaluate: + input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = ( + dummy_supervised_receiver_fn()) + if predict: + input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = ( + dummy_serving_receiver_fn()) + + export_base_path = self._get_tmp_dir() + export_dir = contrib_export.export_all_saved_models( + est, export_base_path, input_receiver_fn_map) + return export_dir + + def test_load_all_modes(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + sme.train(input_fn=dummy_input_fn, steps=1) + sme.train(input_fn=dummy_input_fn, steps=2) + self.assertEqual(13, sme.get_variable_value('global_step')) + self.assertEqual(60, sme.get_variable_value('some_var')) + + eval_results = sme.evaluate(dummy_input_fn, steps=5) + + self.assertEqual(13, eval_results['global_step']) + self.assertEqual(106, eval_results['loss']) + self.assertEqual(502, eval_results['metrics/abs_err']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_load_all_modes_no_train(self): + """Ensure that all functions can be used without requiring a ckpt.""" + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + eval_results = sme.evaluate(dummy_input_fn, steps=5) + self.assertEqual(10, eval_results['global_step']) + self.assertEqual(106, eval_results['loss']) + self.assertEqual(502, eval_results['metrics/abs_err']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_partial_exported_estimator(self): + sme1 = saved_model_estimator.SavedModelEstimator( + self._export_estimator(train=False, predict=False), self._get_tmp_dir()) + sme1.evaluate(dummy_input_fn, steps=5) + with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'): + sme1.train(input_fn=dummy_input_fn, steps=1) + with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'): + next(sme1.predict(dummy_input_fn_features_only)) + + sme2 = saved_model_estimator.SavedModelEstimator( + self._export_estimator(evaluate=False), self._get_tmp_dir()) + sme2.train(input_fn=dummy_input_fn, steps=1) + next(sme2.predict(dummy_input_fn_features_only)) + with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'): + sme2.evaluate(dummy_input_fn, steps=5) + + def test_with_incorrect_input(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + + def bad_shape_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([1, 2], dtype=dtypes.int64)}, + constant_op.constant([1, 2], dtype=dtypes.float32))) + + with self.assertRaisesRegexp(ValueError, 'Expected shape'): + sme.train(bad_shape_input_fn, steps=1) + + def bad_dtype_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)}, + constant_op.constant([[1], [1]], dtype=dtypes.int64))) + + with self.assertRaisesRegexp(ValueError, 'Expected dtype'): + sme.train(bad_dtype_input_fn, steps=1) + + def test_input_fn_with_global_step(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + + def bad_input_fn(): + training.get_or_create_global_step() + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)}, + constant_op.constant([[1], [1]], dtype=dtypes.float32))) + + with self.assertRaisesRegexp(RuntimeError, + 'Graph must not contain a global step tensor'): + sme.train(bad_input_fn, steps=1) + + def test_re_export_saved_model_serving_only(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + sme.train(dummy_input_fn, steps=3) + self.assertEqual(13, sme.get_variable_value('global_step')) + self.assertEqual(60, sme.get_variable_value('some_var')) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + # Export SavedModel, and test that the variable and prediction values are + # the same. + sme_export_dir = sme.export_savedmodel( + self._get_tmp_dir(), dummy_serving_receiver_fn()) + + sme2 = saved_model_estimator.SavedModelEstimator( + sme_export_dir, self._get_tmp_dir()) + self.assertEqual(60, sme.get_variable_value('some_var')) + self.assertEqual(13, sme.get_variable_value('global_step')) + + predictions = next(sme2.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_re_export_saved_model(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 10}, + sme.evaluate(dummy_input_fn, steps=1)) + + sme.train(dummy_input_fn, steps=3) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, + sme.evaluate(dummy_input_fn, steps=1)) + self.assertEqual(60, sme.get_variable_value('some_var')) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + # Export SavedModel for all modes + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()} + sme_export_dir = contrib_export.export_all_saved_models( + sme, self._get_tmp_dir(), input_receiver_fn_map) + + sme2 = saved_model_estimator.SavedModelEstimator( + sme_export_dir, self._get_tmp_dir()) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, + sme.evaluate(dummy_input_fn, steps=1)) + self.assertEqual(60, sme.get_variable_value('some_var')) + + sme.train(dummy_input_fn, steps=7) + self.assertEqual(20, sme.get_variable_value('global_step')) + + predictions = next(sme2.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_load_saved_model_from_serving_only(self): + def model_fn(features, labels, mode): + _, _ = features, labels + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant([103]), + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions=constant_op.constant([502]), + export_outputs={'test': export_output.ClassificationOutput( + constant_op.constant([[32.]]))}) + + est = estimator.Estimator(model_fn, self._get_tmp_dir()) + est.train(input_fn=dummy_input_fn, steps=10) + + def serving_input_receiver_fn(): + return export.ServingInputReceiver( + {'test-features': constant_op.constant([[1], [1]])}, + array_ops.placeholder(dtype=dtypes.string)) + + export_dir = est.export_savedmodel( + self._get_tmp_dir(), serving_input_receiver_fn) + + sme = saved_model_estimator.SavedModelEstimator( + export_dir, self._get_tmp_dir()) + + def input_fn(): + return {'inputs': constant_op.constant('someinputstr')} + + prediction = next(sme.predict(input_fn)) + self.assertDictEqual({'scores': 32}, prediction) + + def test_with_local_init_op(self): + def model_fn(features, labels, mode): + _, _ = features, labels + v = variables.Variable(21, name='some_var') + scaffold = monitored_session.Scaffold( + local_init_op=state_ops.assign_add(v, -3).op + ) + return model_fn_lib.EstimatorSpec( + mode, + scaffold=scaffold, + train_op=state_ops.assign_add(training.get_global_step(), 1), + loss=array_ops.identity(v)) + export_dir = self._export_estimator(predict=False, model_fn=model_fn) + sme = saved_model_estimator.SavedModelEstimator( + export_dir, self._get_tmp_dir()) + + eval_results1 = sme.evaluate(dummy_input_fn, steps=2) + self.assertEqual(15, eval_results1['loss']) + + sme.train(dummy_input_fn, steps=1) + self.assertEqual(15, sme.get_variable_value('some_var')) + + eval_results2 = sme.evaluate(dummy_input_fn, steps=5) + self.assertEqual(12, eval_results2['loss']) + + def test_with_working_input_fn(self): + def model_fn(features, labels, mode): + loss = None + if labels is not None: + loss = labels[0][0] + labels[1][0] + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions={'features_0': array_ops.identity([features['x'][0][0]]), + 'features_1': array_ops.identity([features['x'][1][0]])}) + + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(model_fn=model_fn), self._get_tmp_dir()) + eval_results = sme.evaluate(dummy_input_fn, steps=1) + self.assertEqual(1, eval_results['loss']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions) + + def test_control_dependency(self): + # Control dependencies are saved with "^" appended to the start of the input + # name. The input map must include control dependencies as well. + def model_fn(features, labels, mode): + _ = labels + with ops.control_dependencies([features['x']]): + loss = features['x'][1][0] + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=state_ops.assign_add(training.get_global_step(), 1)) + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(train=False, predict=False, model_fn=model_fn), + self._get_tmp_dir()) + sme.evaluate(dummy_input_fn, steps=1) # Should run without error + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index dc49383c5c300e82839c478e097074b3e8776b3b..918a7e2bc772dee226e5ef23d0e3e34309f180f4 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -133,6 +133,7 @@ _nest_allowed_symbols = [ 'flatten_dict_items', 'pack_sequence_as', 'map_structure', + 'map_structure_with_paths', 'assert_shallow_structure', 'flatten_up_to', 'map_structure_up_to', diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index 9e356dd96562c28adec7fc28fe144394e1c2ed38..e7184a01fbf57319399fc6dd287b7387138b4058 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as train __all__ = [ @@ -40,7 +40,7 @@ __all__ = [ def _get_checkpoint_filename(filepattern): """Returns checkpoint filename given directory or specific filepattern.""" if gfile.IsDirectory(filepattern): - return saver.latest_checkpoint(filepattern) + return checkpoint_management.latest_checkpoint(filepattern) return filepattern diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 7e6cb724853b3a0aaaefc1137373063fb78f5549..053d4e3e977ed1baed8ceeca1a983e999b1ad1ff 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -196,11 +196,16 @@ py_test( srcs = ["python/losses/python/tuple_losses_test.py"], srcs_version = "PY2AND3", deps = [ + ":losses_impl", ":namedtuples", ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 8e4affb9b4f95bf5afab0f50c86954e60a942279..ab9886580d1648852e08f64cb3e9b51f679c25de 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -53,9 +53,6 @@ _summary_type_map = { } -# TODO(joelshor): For now, this only supports 1:1 generator:discriminator -# training sequentially. Find a nice way to expose options to the user without -# exposing internals. class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index dcc3f94c2d6b9e5e44036e7cc1a9d1bb39104fb5..221c70c38bd432a6be7f6cda9c6700aa2255821f 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -80,6 +80,9 @@ __all__ = [ 'mutual_information_penalty', 'combine_adversarial_loss', 'cycle_consistency_loss', + 'stargan_generator_loss_wrapper', + 'stargan_discriminator_loss_wrapper', + 'stargan_gradient_penalty_wrapper' ] @@ -277,3 +280,86 @@ def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False): cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x, cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y, scope, add_summaries) + + +def stargan_generator_loss_wrapper(loss_fn): + """Convert a generator loss function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking Discriminator's real/fake prediction for + generated data. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + return loss_fn( + stargan_model.discriminator_generated_data_source_predication, **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn + + +def stargan_discriminator_loss_wrapper(loss_fn): + """Convert a discriminator loss function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking Discriminator's real/fake prediction for + real data and generated data. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + return loss_fn( + stargan_model.discriminator_input_data_source_predication, + stargan_model.discriminator_generated_data_source_predication, **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn + + +def stargan_gradient_penalty_wrapper(loss_fn): + """Convert a gradient penalty function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking real_data, generated_data, + generator_inputs for Discriminator's condition (i.e. number of domains), + discriminator_fn, and discriminator_scope. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + num_domains = stargan_model.input_data_domain_label.shape.as_list()[-1] + return loss_fn( + real_data=stargan_model.input_data, + generated_data=stargan_model.generated_data, + generator_inputs=num_domains, + discriminator_fn=stargan_model.discriminator_fn, + discriminator_scope=stargan_model.discriminator_scope, + **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index aa1ef11172dee6799994b87f70a3883cd67fd15b..a559bbfa11367afd7dfe6a72d2ce2cc9d7ba1f16 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -22,10 +22,15 @@ import collections import numpy as np +from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -129,6 +134,9 @@ manual_tests = [ 'mutual_information_penalty', 'wasserstein_gradient_penalty', 'cycle_consistency_loss', + 'stargan_generator_loss_wrapper', + 'stargan_discriminator_loss_wrapper', + 'stargan_gradient_penalty_wrapper' ] discriminator_keyword_args = { @@ -175,6 +183,112 @@ class CycleConsistencyLossTest(test.TestCase): self.assertNear(5.0, loss.eval(), 1e-5) +class StarGANLossWrapperTest(test.TestCase): + + def setUp(self): + + super(StarGANLossWrapperTest, self).setUp() + + self.input_data = array_ops.ones([1, 2, 2, 3]) + self.input_data_domain_label = constant_op.constant([[0, 1]]) + self.generated_data = array_ops.ones([1, 2, 2, 3]) + self.discriminator_input_data_source_predication = array_ops.ones([1]) + self.discriminator_generated_data_source_predication = array_ops.ones([1]) + + def _discriminator_fn(inputs, num_domains): + """Differentiable dummy discriminator for StarGAN.""" + hidden = layers.flatten(inputs) + output_src = math_ops.reduce_mean(hidden, axis=1) + output_cls = layers.fully_connected( + inputs=hidden, + num_outputs=num_domains, + activation_fn=None, + normalizer_fn=None, + biases_initializer=None) + return output_src, output_cls + + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + + self.model = namedtuples.StarGANModel( + input_data=self.input_data, + input_data_domain_label=self.input_data_domain_label, + generated_data=self.generated_data, + generated_data_domain_target=None, + reconstructed_data=None, + discriminator_input_data_source_predication=self. + discriminator_input_data_source_predication, + discriminator_generated_data_source_predication=self. + discriminator_generated_data_source_predication, + discriminator_input_data_domain_predication=None, + discriminator_generated_data_domain_predication=None, + generator_variables=None, + generator_scope=None, + generator_fn=None, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=_discriminator_fn) + + self.discriminator_fn = _discriminator_fn + self.discriminator_scope = dis_scope + + def test_stargan_generator_loss_wrapper(self): + """Test StarGAN generator loss wrapper.""" + loss_fn = tfgan_losses_impl.wasserstein_generator_loss + wrapped_loss_fn = tfgan_losses.stargan_generator_loss_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + self.discriminator_generated_data_source_predication) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + def test_stargan_discriminator_loss_wrapper(self): + """Test StarGAN discriminator loss wrapper.""" + loss_fn = tfgan_losses_impl.wasserstein_discriminator_loss + wrapped_loss_fn = tfgan_losses.stargan_discriminator_loss_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + self.discriminator_generated_data_source_predication, + self.discriminator_generated_data_source_predication) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + def test_stargan_gradient_penalty_wrapper(self): + """Test StaGAN gradient penalty wrapper. + + Notes: + The random interpolates are handled by given setting the reconstruction to + be the same as the input. + + """ + loss_fn = tfgan_losses_impl.wasserstein_gradient_penalty + wrapped_loss_fn = tfgan_losses.stargan_gradient_penalty_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + real_data=self.input_data, + generated_data=self.generated_data, + generator_inputs=self.input_data_domain_label.shape.as_list()[-1], + discriminator_fn=self.discriminator_fn, + discriminator_scope=self.discriminator_scope) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + if __name__ == '__main__': for loss_name in tfgan_losses.__all__: if loss_name in manual_tests: continue diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index df603d1f183ee65893d6933ed77c52f52e5181e9..03f52d214b5ac2fef075fb66018f88d2be5c1941 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -34,6 +34,7 @@ from __future__ import print_function from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import losses as tfgan_losses from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl from tensorflow.contrib.slim.python.slim import learning as slim_learning from tensorflow.contrib.training.python.training import training from tensorflow.python.framework import dtypes @@ -41,14 +42,17 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.distributions import distribution as ds from tensorflow.python.ops.losses import losses +from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util + __all__ = [ 'gan_model', 'infogan_model', @@ -751,6 +755,130 @@ def cyclegan_loss( return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) +def stargan_loss( + model, + generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper( + tfgan_losses_impl.wasserstein_generator_loss), + discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper( + tfgan_losses_impl.wasserstein_discriminator_loss), + gradient_penalty_weight=10.0, + gradient_penalty_epsilon=1e-10, + gradient_penalty_target=1.0, + gradient_penalty_one_sided=False, + reconstruction_loss_fn=losses.absolute_difference, + reconstruction_loss_weight=10.0, + classification_loss_fn=losses.softmax_cross_entropy, + classification_loss_weight=1.0, + classification_one_hot=True, + add_summaries=True): + """StarGAN Loss. + + The four major part can be found here: http://screen/tMRMBAohDYG. + + Args: + model: (StarGAN) Model output of the stargan_model() function call. + generator_loss_fn: The loss function on the generator. Takes a + `StarGANModel` named tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `StarGANModel` namedtuple. + gradient_penalty_weight: (float) Gradient penalty weight. Default to 10 per + the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to + turn off gradient penalty. + gradient_penalty_epsilon: (float) A small positive number added for + numerical stability when computing the gradient norm. + gradient_penalty_target: (float, or tf.float `Tensor`) The target value of + gradient norm. Defaults to 1.0. + gradient_penalty_one_sided: (bool) If `True`, penalty proposed in + https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. + reconstruction_loss_fn: The reconstruction loss function. Default to L1-norm + and the function must conform to the `tf.losses` API. + reconstruction_loss_weight: Reconstruction loss weight. Default to 10.0. + classification_loss_fn: The loss function on the discriminator's ability to + classify domain of the input. Default to one-hot softmax cross entropy + loss, and the function must conform to the `tf.losses` API. + classification_loss_weight: (float) Classification loss weight. Default to + 1.0. + classification_one_hot: (bool) If the label is one hot representation. + Default to True. If False, classification classification_loss_fn need to + be sigmoid cross entropy loss instead. + add_summaries: (bool) Add the loss to the summary + + Returns: + GANLoss namedtuple where we have generator loss and discriminator loss. + + Raises: + ValueError: If input StarGANModel.input_data_domain_label does not have rank + 2, or dimension 2 is not defined. + """ + + def _classification_loss_helper(true_labels, predict_logits, scope_name): + """Classification Loss Function Helper. + + Args: + true_labels: Tensor of shape [batch_size, num_domains] representing the + label where each row is an one-hot vector. + predict_logits: Tensor of shape [batch_size, num_domains] representing the + predicted label logit, which is UNSCALED output from the NN. + scope_name: (string) Name scope of the loss component. + + Returns: + Single scalar tensor representing the classification loss. + """ + + with ops.name_scope(scope_name, values=(true_labels, predict_logits)): + + loss = classification_loss_fn( + onehot_labels=true_labels, logits=predict_logits) + + if not classification_one_hot: + loss = math_ops.reduce_sum(loss, axis=1) + loss = math_ops.reduce_mean(loss) + + if add_summaries: + summary.scalar(scope_name, loss) + + return loss + + # Check input shape. + model.input_data_domain_label.shape.assert_has_rank(2) + model.input_data_domain_label.shape[1:].assert_is_fully_defined() + + # Adversarial Loss. + generator_loss = generator_loss_fn(model, add_summaries=add_summaries) + discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries) + + # Gradient Penalty. + if _use_aux_loss(gradient_penalty_weight): + gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper( + tfgan_losses_impl.wasserstein_gradient_penalty) + discriminator_loss += gradient_penalty_fn( + model, + epsilon=gradient_penalty_epsilon, + target=gradient_penalty_target, + one_sided=gradient_penalty_one_sided, + add_summaries=add_summaries) * gradient_penalty_weight + + # Reconstruction Loss. + reconstruction_loss = reconstruction_loss_fn(model.input_data, + model.reconstructed_data) + generator_loss += reconstruction_loss * reconstruction_loss_weight + if add_summaries: + summary.scalar('reconstruction_loss', reconstruction_loss) + + # Classification Loss. + generator_loss += _classification_loss_helper( + true_labels=model.generated_data_domain_target, + predict_logits=model.discriminator_generated_data_domain_predication, + scope_name='generator_classification_loss') * classification_loss_weight + discriminator_loss += _classification_loss_helper( + true_labels=model.input_data_domain_label, + predict_logits=model.discriminator_input_data_domain_predication, + scope_name='discriminator_classification_loss' + ) * classification_loss_weight + + return namedtuples.GANLoss(generator_loss, discriminator_loss) + + def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): """Gets generator and discriminator update ops. diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index fa52e9cca14ecbe91c3fee1058db59a7e9447cee..58f348034fdcaadd8d738517aef2a7e2f0172c13 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -114,6 +114,12 @@ def stargan_generator_model(inputs, _): return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs +class StarGANGenerator(object): + + def __call__(self, inputs, _): + return stargan_generator_model(inputs, _) + + def stargan_discriminator_model(inputs, num_domains): """Differentiable dummy discriminator for StarGAN.""" @@ -130,6 +136,12 @@ def stargan_discriminator_model(inputs, num_domains): return output_src, output_cls +class StarGANDiscriminator(object): + + def __call__(self, inputs, num_domains): + return stargan_discriminator_model(inputs, num_domains) + + def get_gan_model(): # TODO(joelshor): Find a better way of creating a variable scope. with variable_scope.variable_scope('generator') as gen_scope: @@ -272,6 +284,49 @@ def create_callable_cyclegan_model(): data_y=array_ops.ones([1, 2])) +def get_stargan_model(): + """Similar to get_gan_model().""" + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + pass + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + return namedtuples.StarGANModel( + input_data=array_ops.ones([1, 2, 2, 3]), + input_data_domain_label=array_ops.ones([1, 2]), + generated_data=array_ops.ones([1, 2, 2, 3]), + generated_data_domain_target=array_ops.ones([1, 2]), + reconstructed_data=array_ops.ones([1, 2, 2, 3]), + discriminator_input_data_source_predication=array_ops.ones([1]), + discriminator_generated_data_source_predication=array_ops.ones([1]), + discriminator_input_data_domain_predication=array_ops.ones([1, 2]), + discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), + generator_variables=None, + generator_scope=gen_scope, + generator_fn=stargan_generator_model, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=stargan_discriminator_model) + + +def get_callable_stargan_model(): + model = get_stargan_model() + return model._replace( + generator_fn=StarGANGenerator(), discriminator_fn=StarGANDiscriminator()) + + +def create_stargan_model(): + return train.stargan_model( + stargan_generator_model, stargan_discriminator_model, + array_ops.ones([1, 2, 2, 3]), array_ops.ones([1, 2])) + + +def create_callable_stargan_model(): + return train.stargan_model(StarGANGenerator(), StarGANDiscriminator(), + array_ops.ones([1, 2, 2, 3]), + array_ops.ones([1, 2])) + + def get_sync_optimizer(): return sync_replicas_optimizer.SyncReplicasOptimizer( gradient_descent.GradientDescentOptimizer(learning_rate=1.0), @@ -292,6 +347,8 @@ class GANModelTest(test.TestCase, parameterized.TestCase): ('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel), ('callable_cyclegan', get_callable_cyclegan_model, namedtuples.CycleGANModel), + ('stargan', get_stargan_model, namedtuples.StarGANModel), + ('callabel_stargan', get_callable_stargan_model, namedtuples.StarGANModel) ) def test_output_type(self, create_fn, expected_tuple_type): """Test that output type is as expected.""" @@ -608,6 +665,27 @@ class GANLossTest(test.TestCase, parameterized.TestCase): self.assertTrue(np.isscalar(loss_y2x_gen_np)) self.assertTrue(np.isscalar(loss_y2x_dis_np)) + @parameterized.named_parameters( + ('notcallable', create_stargan_model), + ('callable', create_callable_stargan_model), + ) + def test_stargan(self, create_gan_model_fn): + + model = create_gan_model_fn() + model_loss = train.stargan_loss(model) + + self.assertIsInstance(model_loss, namedtuples.GANLoss) + + with self.test_session() as sess: + + sess.run(variables.global_variables_initializer()) + + gen_loss, disc_loss = sess.run( + [model_loss.generator_loss, model_loss.discriminator_loss]) + + self.assertTrue(np.isscalar(gen_loss)) + self.assertTrue(np.isscalar(disc_loss)) + @parameterized.named_parameters( ('gan', create_gan_model), ('callable_gan', create_callable_gan_model), diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index f3bbf6b4d78b50b11e23abd584bacff8f3d877c7..7e6a0f14f6f5e467801fef39ebb597565b3d7e98 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -174,7 +174,7 @@ class GdrMemoryManager : public RemoteMemoryManager { // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ - GUARDED_BY(cient_mu_); + GUARDED_BY(client_mu_); // Managed memory regions mutex alloc_mu_; diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index 1f9e82b41bf09b235e93fa512a50ea4c3047c01b..cb649a37510c301cb3df997f844617e9a4e6c7be 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,10 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras.preprocessing.image import apply_transform from tensorflow.python.keras.preprocessing.image import array_to_img from tensorflow.python.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras.preprocessing.image import flip_axis from tensorflow.python.keras.preprocessing.image import ImageDataGenerator from tensorflow.python.keras.preprocessing.image import img_to_array from tensorflow.python.keras.preprocessing.image import Iterator diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index bc3359693562deb1229a78a2db5c256c76f7fd8d..a7b41b714ffaa062e2eba8caf9b4fa033c7633cd 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -121,6 +121,7 @@ from tensorflow.contrib.layers.python.layers import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['bias_add', + 'conv1d', 'conv2d', 'conv3d', 'elu', diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index dd602cf3a9b7826a19408a78ef543bb0c4fbf84e..6250f8852917c00c94162ce9711bb8f34051565b 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -55,9 +55,9 @@ from tensorflow.python.training import moving_averages # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. __all__ = [ - 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', - 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv1d', 'conv2d', + 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', + 'convolution', 'convolution1d', 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', @@ -2660,7 +2660,7 @@ def separable_convolution2d( inputs, num_outputs, kernel_size, - depth_multiplier, + depth_multiplier=1, stride=1, padding='SAME', data_format=DATA_FORMAT_NHWC, @@ -3320,6 +3320,7 @@ relu6 = functools.partial(fully_connected, activation_fn=nn.relu6) linear = functools.partial(fully_connected, activation_fn=None) # Simple alias. +conv1d = convolution1d conv2d = convolution2d conv3d = convolution3d conv2d_transpose = convolution2d_transpose diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 7a026a15e4aeea0dde4ed9f7de053a757a0abb58..c1de42782efb3497660affb3ef7162457977c150 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary as core_summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import device_setter from tensorflow.python.training import monitored_session from tensorflow.python.training import saver @@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, as_iterable=True, iterate_batches=False): # Check that model has been trained. - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator): if not checkpoint_path: # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 7cb87619d960a03f342c7441730aaf2c4f15eb38..c36879e0483c92db0cc08dedbb483bcc288d4894 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -302,6 +302,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): # so instead of breaking compatibility with that assumption, we # just manually initialize this field: self._train_distribute = None + self._eval_distribute = None self._device_fn = None gpu_options = config_pb2.GPUOptions( diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index f8a3709ee57a32734afa7ac8133271c75d152b2c..08e907a608b0c6df6e7ac9d9675f7f9e2b84ff5d 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): # Load and cache the path of the most recent checkpoint to avoid duplicate # searches on GCS. logging.info("Checking for checkpoint in %s", self._model_dir) - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: logging.warning("Skipping evaluation and export since model has not been " @@ -516,7 +516,8 @@ class Experiment(object): start = time.time() error_msg = None - latest_path = saver.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if not latest_path: error_msg = ("Estimator is not fitted yet. " "Will start an evaluation when a checkpoint is ready.") @@ -778,7 +779,8 @@ class Experiment(object): saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + latest_checkpoint = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) eval_result = self._call_evaluate( input_fn=self._eval_input_fn, steps=self._eval_steps, diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index 0d039d593b7850ead34484f88426255dc659b7fc..df156da3f467538ed1c6b640d651fdfd33ce243d 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib @@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path @@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 77f7c73d5412d40b338eaff4cf04d99fd0892723..3d691d434044aab1e3e86457cee6aadb5bf798c7 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as core_summary -from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.util import deprecation @@ -735,7 +735,8 @@ class ValidationMonitor(EveryN): return False self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.debug("Skipping evaluation since model has not been saved yet " "at step %d.", step) @@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN): def end(self, session=None): super(ExportMonitor, self).end(session=session) - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.info("Skipping export at the end since model has not been saved " "yet.") diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index 5c34d0ddb01f3bcdc407e6926e7c5b73be1863b4..ff1da32c218b4e105b5503426ac01410665f9c7e 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -39,9 +39,9 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver from tensorflow.python.training import training_util @@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase): mock_latest_checkpoint.assert_called_with(model_dir) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_early_stopping_rounds(self, mock_latest_checkpoint, mock_estimator_class): @@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase): self._assert_validation_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase): monitor.epoch_end(epoch=0) monitor.end() - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) model_dir = 'model/dir' @@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase): expected_best_metrics={'loss': 42.0, 'auc': 0.5}) monitor.post_step(step=step, session=None) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_fail_with_core_estimator_and_metrics( self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 3eacac7a3d3dcff4d39025fdee88e16e385b1b84..0144b93814a174cfb8c3162f407a595ac637f4f5 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util @@ -298,7 +299,8 @@ def _export_estimator(estimator, # If checkpoint_path is specified, use the specified checkpoint path. checkpoint_path = (checkpoint_path or - tf_saver.latest_checkpoint(estimator._model_dir)) + checkpoint_management.latest_checkpoint( + estimator._model_dir)) with ops.Graph().as_default() as g: training_util.create_global_step(g) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index f8106d1e4a7e79f1cd651c40995be480721a8129..66af6833da1644fa4f73a24987079c9ffc8cecce 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated @@ -714,7 +714,8 @@ def make_best_model_export_strategy( # as soon as contrib is cleaned up and we can thus be sure that # estimator is a tf.estimator.Estimator and not a # tf.contrib.learn.Estimator - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index fe0ba19fcbe90edbeb1445e1fea77c36cf3ba170..7534b50a4ae0076fb27fb9cd0d1dd58b29192876 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -41,7 +41,10 @@ py_test( size = "medium", srcs = ["python/kernel_tests/sdca_ops_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows_gpu"], + tags = [ + "no_gpu", + "no_pip_gpu", + ], deps = [ ":sdca_ops_py", ":sparse_feature_column_py", diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 7d7dd6b7088f457b1a14a3ff30b7eef98c00d18a..1e6f1e7da212c3aeb1563dc2f4b6dff2cb550736 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -125,10 +125,22 @@ cc_library( "graph_info.cc", "interpreter.cc", "model.cc", - "nnapi_delegate.cc", "op_resolver.cc", "optional_debug_tools.cc", - ], + ] + select({ + "//tensorflow:android": [ + "nnapi_delegate.cc", + "mmap_allocation.cc", + ], + "//tensorflow:windows": [ + "nnapi_delegate_disabled.cc", + "mmap_allocation_disabled.cc", + ], + "//conditions:default": [ + "nnapi_delegate_disabled.cc", + "mmap_allocation.cc", + ], + }), hdrs = [ "allocation.h", "context.h", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index df5954744a41191d922e91553303e052969c24fb..9cc8f10b4290030898cffa8a8cac6ba395a30e2e 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -95,6 +95,7 @@ ARFLAGS := -r INCLUDES := \ -I. \ -I$(MAKEFILE_DIR)/../../../ \ +-I$(MAKEFILE_DIR)/../../../../ \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ @@ -176,8 +177,12 @@ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ $(MINIMAL_SRCS) ifeq ($(BUILD_TYPE),micro) CORE_CC_EXCLUDE_SRCS += \ -tensorflow/contrib/lite/model.cc \ +tensorflow/contrib/lite/mmap_allocation.cc \ tensorflow/contrib/lite/nnapi_delegate.cc +else +CORE_CC_EXCLUDE_SRCS += \ +tensorflow/contrib/lite/mmap_allocation_disabled.cc \ +tensorflow/contrib/lite/nnapi_delegate_disabled.cc endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) @@ -214,8 +219,12 @@ all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) # The target that's compiled for micro-controllers micro: $(LIB_PATH) +# Hack for generating schema file bypassing flatbuffer parsing +tensorflow/contrib/lite/schema/schema_generated.h: + @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h + # Gathers together all the objects we've compiled into a single '.a' archive. -$(LIB_PATH): $(LIB_OBJS) +$(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index c42622ff02fc2837b61b35f19e834276c0518d1e..89462618148a2afbcf2ef6b1dd2985bcd0178734 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -13,61 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#ifndef TFLITE_MCU -#include -#endif +#include "tensorflow/contrib/lite/allocation.h" + #include #include -#include #include #include #include #include #include -#include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" -#ifndef TFLITE_MCU -#include "tensorflow/contrib/lite/nnapi_delegate.h" -#endif namespace tflite { #ifndef TFLITE_MCU -MMAPAllocation::MMAPAllocation(const char* filename, - ErrorReporter* error_reporter) - : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { - mmap_fd_ = open(filename, O_RDONLY); - if (mmap_fd_ == -1) { - error_reporter_->Report("Could not open '%s'.", filename); - return; - } - struct stat sb; - fstat(mmap_fd_, &sb); - buffer_size_bytes_ = sb.st_size; - mmapped_buffer_ = - mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); - if (mmapped_buffer_ == MAP_FAILED) { - error_reporter_->Report("Mmap of '%s' failed.", filename); - return; - } -} - -MMAPAllocation::~MMAPAllocation() { - if (valid()) { - munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); - } - if (mmap_fd_ != -1) close(mmap_fd_); -} - -const void* MMAPAllocation::base() const { return mmapped_buffer_; } - -size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } - -bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } - FileCopyAllocation::FileCopyAllocation(const char* filename, ErrorReporter* error_reporter) : Allocation(error_reporter) { @@ -99,7 +60,9 @@ FileCopyAllocation::FileCopyAllocation(const char* filename, filename); return; } - copied_buffer_ = std::move(buffer); + // Versions of GCC before 6.2.0 don't support std::move from non-const + // char[] to const char[] unique_ptrs. + copied_buffer_.reset(const_cast(buffer.release())); } FileCopyAllocation::~FileCopyAllocation() {} @@ -109,6 +72,7 @@ const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } +#endif MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, ErrorReporter* error_reporter) @@ -116,7 +80,6 @@ MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, buffer_ = ptr; buffer_size_bytes_ = num_bytes; } -#endif MemoryAllocation::~MemoryAllocation() {} diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h index 827ea86503f910714971e2b138295b9a5809dfd5..121f3d264687933f45f3a2c5d2a53ad80d594ca9 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/contrib/lite/allocation.h @@ -52,6 +52,8 @@ class MMAPAllocation : public Allocation { size_t bytes() const override; bool valid() const override; + static bool IsSupported(); + protected: // Data required for mmap. int mmap_fd_ = -1; // mmap file descriptor diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 7c13f9011ea691670dd014dda7cb92b44b5dc681..81844756bc7239fa798ff96b8b093afdf9ea9557 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -2,8 +2,8 @@ load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", "tf_cc_shared_object", + "tf_cc_test", ) def tflite_copts(): @@ -27,6 +27,9 @@ def tflite_copts(): str(Label("//tensorflow:ios_x86_64")): [ "-msse4.1", ], + str(Label("//tensorflow:windows")): [ + "/DTF_COMPILE_LIBRARY", + ], "//conditions:default": [], }) + select({ str(Label("//tensorflow:with_default_optimizations")): [], @@ -53,6 +56,7 @@ def tflite_linkopts_unstripped(): "-Wl,--gc-sections", # Eliminate unused code and data. "-Wl,--as-needed", # Don't link unused libs. ], + "//tensorflow:darwin": [], "//tensorflow/contrib/lite:mips": [], "//tensorflow/contrib/lite:mips64": [], "//conditions:default": [ @@ -74,6 +78,7 @@ def tflite_jni_linkopts_unstripped(): "-Wl,--gc-sections", # Eliminate unused code and data. "-Wl,--as-needed", # Don't link unused libs. ], + "//tensorflow:darwin": [], "//tensorflow/contrib/lite:mips": [], "//tensorflow/contrib/lite:mips64": [], "//conditions:default": [ @@ -122,19 +127,21 @@ def tflite_jni_binary( linkopts = linkopts, ) -def tflite_cc_shared_object(name, - copts=tflite_copts(), - linkopts=[], - linkstatic=1, - deps=[]): - """Builds a shared object for TFLite.""" - tf_cc_shared_object( - name=name, - copts=copts, - linkstatic=linkstatic, - linkopts=linkopts + tflite_jni_linkopts(), - framework_so=[], - deps=deps) +def tflite_cc_shared_object( + name, + copts = tflite_copts(), + linkopts = [], + linkstatic = 1, + deps = []): + """Builds a shared object for TFLite.""" + tf_cc_shared_object( + name = name, + copts = copts, + linkstatic = linkstatic, + linkopts = linkopts + tflite_jni_linkopts(), + framework_so = [], + deps = deps, + ) def tf_to_tflite(name, src, options, out): """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. @@ -240,6 +247,9 @@ def generated_test_models(): "local_response_norm", "log_softmax", "log", + "logical_and", + "logical_or", + "logical_xor", "lstm", "max_pool", "maximum", @@ -248,6 +258,7 @@ def generated_test_models(): "mul", "neg", "not_equal", + "one_hot", "pack", "pad", "padv2", diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index fd16aa1063ed6fb17144ede3efdb09cd17cd816e..70178b2faabe85f8a53a94c2b5d2e3ea40c8ba05 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -282,6 +282,10 @@ typedef struct { int axis; } TfLitePackParams; +typedef struct { + int axis; +} TfLiteOneHotParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 1ae73b97386ae67ede4535a6ecbf6bf600549046..8a8eb9856886538a1483141ab5f67f54613ea2a1 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -110,6 +110,9 @@ typedef enum { kTfLiteBuiltinReduceMax = 82, kTfLiteBuiltinPack = 83, kTfLiteBuiltinLogicalOr = 84, + kTfLiteBuiltinOneHot = 85, + kTfLiteBuiltinLogicalAnd = 86, + kTfLiteBuiltinLogicalNot = 87, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index cbfce12d7e5f994088b3d9b951aa2d59d6630605..5bc20106d31357e2da3f005baee0f8d134d37be2 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -29,6 +29,9 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#if defined(_MSC_VER) +#include +#endif #include #include #include @@ -180,7 +183,11 @@ typedef union { uint8_t* uint8; bool* b; int16_t* i16; +#if defined(_MSC_VER) + _Fcomplex* c64; +#else _Complex float* c64; +#endif } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index 03a4b7bf1dd8c51dcc26df5caeacc0ad3653339e..f21540d524e77b16a5a9fe1b66781eb6faeddd39 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -38,6 +38,42 @@ cc_test( ], ) +cc_library( + name = "delegate", + srcs = [ + "delegate.cc", + ], + hdrs = [ + "delegate.h", + ], + deps = [ + ":buffer_map", + ":delegate_data", + ":kernel", + ":util", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "delegate_test", + size = "small", + srcs = ["delegate_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":delegate", + ":test_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "delegate_data", srcs = ["delegate_data.cc"], @@ -67,11 +103,59 @@ cc_test( ], ) +cc_library( + name = "kernel", + srcs = ["kernel.cc"], + hdrs = ["kernel.h"], + deps = [ + ":delegate_data", + ":util", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "@flatbuffers", + ], +) + +cc_test( + name = "kernel_test", + size = "small", + srcs = ["kernel_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":delegate_data", + ":kernel", + ":test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", + "@flatbuffers", + ], +) + cc_library( name = "util", srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + ":constants", "//tensorflow/c:c_api_internal", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", @@ -95,3 +179,8 @@ cc_test( "@com_google_googletest//:gtest", ], ) + +cc_library( + name = "constants", + hdrs = ["constants.h"], +) diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc index 1d6453f498a9474697240843ff8aff0d830e6a4f..e5a19c39976969a0b05b28596c6d7d5ebe7c7782 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc @@ -91,6 +91,10 @@ void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { for (int i = 0; i < num_dims; ++i) { shape.AddDim(tensor->dims->data[i]); } + // TODO(ahentz): we assume this is a new tensor and allocate a new buffer + // for it. This is not always the best approach. For example, this might + // be a reallocation after resizing tensors. In that case we would be + // preferable to somehow reuse the buffer. auto* buf = new TfLiteTensorBuffer(tensor); tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor( GetTensorFlowDataType(tensor->type), shape, buf); diff --git a/tensorflow/contrib/lite/delegates/eager/constants.h b/tensorflow/contrib/lite/delegates/eager/constants.h new file mode 100644 index 0000000000000000000000000000000000000000..7ed6ab7552792c68e6d90056c83c3c574c3f69f7 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/constants.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ + +namespace tflite { +namespace eager { + +// The prefix of Eager op custom code. +// This will be matched agains the `custom_code` field in `OperatorCode` +// Flatbuffer Table. +constexpr char kCustomCodePrefix[] = "Eager"; + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..7d22b454199e2c0d9b8fea05086a7c62d7cdbe81 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" + +#include + +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { +namespace delegate { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Get the nodes in the current execution plan. Interpreter owns this array. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + + // Add all custom ops starting with "Eager" to list of supported nodes. + std::vector supported_nodes; + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + + if (IsEagerOp(registration->custom_name)) { + supported_nodes.push_back(node_index); + } + } + + // Request TFLite to partition the graph and make kernels for each independent + // subgraph. + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + context->ReplaceSubgraphsWithDelegateKernels(context, GetKernel(), + size_and_nodes, delegate); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) { + // TODO(nupurgarg): Make BufferMap unique to each interpreter in order to + // support multiple interpreters using a single delegate. + BufferMap* buffer_map = + reinterpret_cast(delegate->data_)->GetBufferMap(); + + // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf. + if (!buffer_map->HasTensor(buffer_handle)) { + fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle); + return kTfLiteError; + } + + tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle); + tensorflow::StringPiece t_data = t.tensor_data(); + + if (size != t_data.size()) { + fprintf(stderr, "Not enough space to store TensorFlow's aligned buffer.\n"); + return kTfLiteError; + } + + memcpy(data, t_data.data(), t_data.size()); + return kTfLiteOk; +} + +} // namespace delegate +} // namespace eager + +EagerDelegate::EagerDelegate() {} + +EagerDelegate::~EagerDelegate() {} + +TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) { + if (!delegate_) { + if (!eager::DelegateData::Create(&delegate_data_).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return kTfLiteError; + } + + delegate_.reset(new TfLiteDelegate{ + /*data_=*/delegate_data_.get(), + /*nullptr,*/ &eager::delegate::Prepare, + /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}); + } + + return interpreter->ModifyGraphWithDelegate(delegate_.get(), + /*allow_dynamic_tensors=*/true); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..0defca7c323e81bfb211ac56fd59c8656b320574 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/interpreter.h" + +namespace tflite { + +// WARNING: This is an experimental interface that is subject to change. +// Delegate that can be used to extract parts of a graph that are designed to be +// executed by TensorFlow's runtime via Eager. +// +// The interpreter must be constructed after the EagerDelegate and destructed +// before the EagerDelegate. This delegate can only be used with one +// interpreter. +// +// Usage: +// EagerDelegate delegate; +// ... build interpreter ... +// +// delegate.Apply(interpreter); +// ... run inference ... +// ... destroy interpreter ... +// ... destroy delegate ... +class EagerDelegate { + public: + EagerDelegate(); + ~EagerDelegate(); + + // Modifies the graph loaded in the interpreter. + TfLiteStatus Apply(Interpreter* interpreter); + + private: + std::unique_ptr delegate_data_; + std::unique_ptr delegate_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc index 29687694bd0fba6b496f9b24c630d5929756ed83..0fd5c976f8ca9be16f7e3c5e610573755b40c506 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc @@ -23,7 +23,8 @@ tensorflow::Status DelegateData::Create(std::unique_ptr* data) { std::vector devices; TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( - tensorflow::SessionOptions(), "/device:cpu:*", &devices)); + tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0", + &devices)); std::unique_ptr device_mgr( new tensorflow::DeviceMgr(devices)); diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..88fb34044ec5f8e5b4593638163cd4e6407bf8c8 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" + +#include +#include +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +// TODO(nupurgarg): Add a test with multiple interpreters for one delegate. + +class DelegateTest : public testing::EagerModelTest { + public: + DelegateTest() { + // The delegate needs to be constructed before the interpreter because the + // interpreter references data contained in the delegate. + delegate_.reset(new EagerDelegate()); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~DelegateTest() override { + // The delegate needs to be destructed after the interpreter because the + // interpreter references data contained in the delegate. + delete interpreter_.release(); + delete delegate_.release(); + } + + void ConfigureDelegate() { + CHECK(delegate_->Apply(interpreter_.get()) == kTfLiteOk); + } + + private: + std::unique_ptr delegate_; +}; + +TEST_F(DelegateTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(DelegateTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(DelegateTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +TEST_F(DelegateTest, OnlyTFLite) { + // Only TFLite single op model. + AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3}); + AddTfLiteMulOp({0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(1, {2, 2, 1}); + SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); + ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..172798180762f87e1c080be7788db661a63208b5 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" + +#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/node_def.pb.h" + +// Note: this is part of TF Lite's Eager delegation code which is to be +// completed soon. + +// This is the TF Lite op that is created by the eager delegate to handle +// execution of a supported subgraph. The usual flow is that the delegate +// informs the interpreter of supported nodes in a graph, and each supported +// subgraph is replaced with one instance of this kernel. +// +// The kernel is initialized with TfLiteDelegateParams from which we retrieve +// the global EagerContext and BufferMap, as well as a list of inputs and +// outputs to the subgraph. Those are used to build the OpData, with a list of +// TensorFlow Ops that should be executed in order (which we call an OpNode). +// +// For each node included in the subgraph, we query the interpreter and +// retrieve the associated NodeDef, which is then used to configure the +// corresponding TensorFlow/Eager Op. + +namespace tflite { +namespace eager { +namespace kernel { + +// Controls the lifetime of tensor handles in a vector. +class VectorOfHandles { + public: + explicit VectorOfHandles(int num_elements) : vector_(num_elements, nullptr) {} + + ~VectorOfHandles() { + for (auto* handle : vector_) { + if (handle) handle->Unref(); + } + } + + tensorflow::gtl::InlinedVector* GetVector() { + return &vector_; + } + + tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; } + + private: + tensorflow::gtl::InlinedVector vector_; +}; + +// Executes the TensorFlow op given by 'op_name', with the attributes specified +// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'. +tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context, + BufferMap* buffer_map, const string& op_name, + const tensorflow::NodeDef& nodedef, + const std::vector& inputs, + const std::vector& outputs) { + const tensorflow::AttrTypeMap* attr_types; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types), + " (while processing attributes of '", op_name, "')"); + + tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types); + for (const auto& attr : nodedef.attr()) { + op.MutableAttrs()->Set(attr.first, attr.second); + } + + for (int input_index : inputs) { + if (!buffer_map->HasTensor(input_index)) { + return tensorflow::errors::Internal( + "Cannot read from invalid tensor index ", input_index); + } + auto* handle = new tensorflow::TensorHandle( + buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr); + op.AddInput(handle); + handle->Unref(); + } + + int num_retvals = outputs.size(); + VectorOfHandles retvals(num_retvals); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EagerExecute(&op, retvals.GetVector(), &num_retvals), + " (while executing '", op_name, "' via Eager)"); + + if (num_retvals != outputs.size()) { + return tensorflow::errors::Internal( + "Unexpected number of outputs from EagerExecute"); + } + + for (int i = 0; i < num_retvals; ++i) { + const tensorflow::Tensor* tensor = nullptr; + TF_RETURN_IF_ERROR(retvals.GetHandle(i)->Tensor(&tensor)); + buffer_map->SetFromTensorFlow(outputs[i], *tensor); + } + + return tensorflow::Status::OK(); +} + +// A single node within the larger 'op'. Note that this kernel executes many +// TensorFlow ops within a single TF Lite op. +struct OpNode { + // The name of the TensorFlow op to execute. + string name; + // The corresponding NodeDef, containing the attributes for the op. + tensorflow::NodeDef nodedef; + // List of inputs, as TF Lite tensor indices. + std::vector inputs; + // List of outputs, as TF Lite tensor indices. + std::vector outputs; +}; + +// The Larger 'op', which contains all the nodes in a supported subgraph. +struct OpData { + tensorflow::EagerContext* eager_context; + BufferMap* buffer_map; + std::vector nodes; + std::vector subgraph_inputs; + std::vector subgraph_outputs; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + CHECK(params); + CHECK(params->delegate); + CHECK(params->delegate->data_); + op_data->eager_context = + reinterpret_cast(params->delegate->data_) + ->GetEagerContext(); + op_data->buffer_map = + reinterpret_cast(params->delegate->data_)->GetBufferMap(); + + CHECK(params->output_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { + op_data->subgraph_outputs.push_back(tensor_index); + } + + CHECK(params->input_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { + op_data->subgraph_inputs.push_back(tensor_index); + } + + CHECK(params->nodes_to_replace); + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + + op_data->nodes.push_back(OpNode()); + OpNode& node_data = op_data->nodes.back(); + + node_data.name = ""; + if (node->custom_initial_data) { + // The flexbuffer contains a vector where the first elements is the + // op name and the second is a serialized NodeDef. + const flexbuffers::Vector& v = + flexbuffers::GetRoot( + reinterpret_cast(node->custom_initial_data), + node->custom_initial_data_size) + .AsVector(); + + node_data.name = v[0].AsString().str(); + if (!node_data.nodedef.ParseFromString(v[1].AsString().str())) { + // We will just leave the nodedef empty and error out in Eval(). + node_data.nodedef.Clear(); + } + } + + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + node_data.inputs.push_back(input_index); + } + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + node_data.outputs.push_back(output_index); + } + } + + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_MSG( + context, op_data->eager_context != nullptr, + "Failed to initialize eager context. This often happens when a CPU " + "device has not been registered, presumably because some symbols from " + "tensorflow/core:core_cpu_impl were not linked into the binary."); + + // Whenever we find a constant tensor, insert it in the buffer map. + BufferMap* buffer_map = op_data->buffer_map; + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (IsConstantTensor(tensor)) { + if (!buffer_map->HasTensor(tensor_index)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + } + + // All output tensors are allocated by TensorFlow/Eager, so we + // mark them as kTfLiteDynamic. + for (auto tensor_index : op_data->subgraph_outputs) { + SetTensorToDynamic(&context->tensors[tensor_index]); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + BufferMap* buffer_map = op_data->buffer_map; + tensorflow::EagerContext* eager_context = op_data->eager_context; + + // Insert a tensor in the buffer map for all inputs that are not constant. + // Constants were handled in Prepare() already. + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (!IsConstantTensor(tensor)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + + // Execute the TensorFlow Ops sequentially. + for (const auto& node_data : op_data->nodes) { + if (node_data.nodedef.op().empty()) { + context->ReportError(context, "Invalid NodeDef in Eager op '%s'", + node_data.name.c_str()); + return kTfLiteError; + } + auto status = + ExecuteEagerOp(eager_context, buffer_map, node_data.name, + node_data.nodedef, node_data.inputs, node_data.outputs); + TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); + } + + for (auto tensor_index : op_data->subgraph_outputs) { + if (!buffer_map->HasTensor(tensor_index)) { + context->ReportError(context, "Cannot write to invalid tensor index %d", + tensor_index); + return kTfLiteError; + } + + TfLiteTensor* tensor = &context->tensors[tensor_index]; + TF_LITE_ENSURE_OK( + context, + CopyShape(context, buffer_map->GetTensor(tensor_index), tensor)); + tensor->buffer_handle = tensor_index; + tensor->data_is_stale = true; + } + + return kTfLiteOk; +} + +} // namespace kernel + +TfLiteRegistration GetKernel() { + TfLiteRegistration registration{&kernel::Init, &kernel::Free, + &kernel::Prepare, &kernel::Eval, + nullptr, kTfLiteBuiltinDelegate}; + return registration; +} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..100672c82dcd3eaee17325f3b712140b081e8efe --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace eager { + +// Return the registration object used to initialize and execute ops that will +// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by +// the eager delegate to handle execution of a supported subgraph. The usual +// flow is that the delegate informs the interpreter of supported nodes in a +// graph, and each supported subgraph is replaced with one instance of this +// kernel. +TfLiteRegistration GetKernel(); + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7bfbb34e49c71142e28f0bf1b2f84e0ff570734 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc @@ -0,0 +1,228 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" + +#include +#include +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, + const std::vector& supported_nodes) { + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels( + context, eager::GetKernel(), size_and_nodes, delegate)); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +class KernelTest : public testing::EagerModelTest { + public: + KernelTest() { + CHECK(DelegateData::Create(&delegate_data_).ok()); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~KernelTest() override { + // The data needs to be released before the interpreter because the + // interpreter references the data. + delegate_data_.reset(); + interpreter_.reset(); + } + + template + void ConfigureDelegate(T prepare_function) { + delegate_.data_ = delegate_data_.get(); + delegate_.FreeBufferHandle = nullptr; + delegate_.Prepare = prepare_function; + delegate_.CopyFromBufferHandle = [](TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size) { + auto* delegate_data = reinterpret_cast(delegate->data_); + tensorflow::StringPiece values = + delegate_data->GetBufferMap()->GetTensor(buffer_handle).tensor_data(); + memcpy(data, values.data(), values.size()); + return kTfLiteOk; + }; + CHECK(interpreter_->ModifyGraphWithDelegate( + &delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk); + } + + private: + std::unique_ptr delegate_data_; + TfLiteDelegate delegate_; +}; + +TEST_F(KernelTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply Delegate. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3, 4}); + }); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, BadTensorFlowOp) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + AddTfOp(testing::kNonExistent, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while processing attributes of 'NonExistentOp'")); +} + +TEST_F(KernelTest, BadNumberOfOutputs) { + AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3}); + AddTfOp(testing::kIdentity, {0}, {1, 2}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Unexpected number of outputs")); +} + +TEST_F(KernelTest, IncompatibleNodeDef) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + + // Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp. + AddTfOp(testing::kIncompatibleNodeDef, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while executing 'Cast' via Eager")); +} + +TEST_F(KernelTest, WrongSetOfNodes) { + AddTensors(4, {0}, {3}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfLiteMulOp({1, 2}, {3}); + + // Specify that testing::kMul (#1) is supported when it actually isn't. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Invalid NodeDef in Eager op")); +} + +TEST_F(KernelTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 4, 5}); + }); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..80acf5d9955f92ec06844b4bc3b980b3a924ab8f --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc @@ -0,0 +1,154 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +#include "absl/memory/memory.h" +#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h" + +namespace tflite { +namespace eager { +namespace testing { + +bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } + +void EagerModelTest::SetValues(int tensor_index, + const std::vector& values) { + float* v = interpreter_->typed_tensor(tensor_index); + for (float f : values) { + *v++ = f; + } +} + +std::vector EagerModelTest::GetValues(int tensor_index) { + TfLiteTensor* o = interpreter_->tensor(tensor_index); + return std::vector(o->data.f, o->data.f + o->bytes / sizeof(float)); +} + +void EagerModelTest::SetShape(int tensor_index, + const std::vector& values) { + ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); +} + +std::vector EagerModelTest::GetShape(int tensor_index) { + std::vector result; + auto* dims = interpreter_->tensor(tensor_index)->dims; + result.reserve(dims->size); + for (int i = 0; i < dims->size; ++i) { + result.push_back(dims->data[i]); + } + return result; +} + +void EagerModelTest::AddTensors(int num_tensors, const std::vector& inputs, + const std::vector& outputs, + const TfLiteType& type, + const std::vector& dims) { + interpreter_->AddTensors(num_tensors); + for (int i = 0; i < num_tensors; ++i) { + TfLiteQuantizationParams quant; + CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type, + /*name=*/"", + /*dims=*/dims, quant), + kTfLiteOk); + } + + CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk); + CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); +} + +void EagerModelTest::AddTfLiteMulOp(const std::vector& inputs, + const std::vector& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_MUL; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* o = &context->tensors[node->outputs->data[0]]; + return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims)); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* i1 = &context->tensors[node->inputs->data[1]]; + auto* o = &context->tensors[node->outputs->data[0]]; + for (int i = 0; i < o->bytes / sizeof(float); ++i) { + o->data.f[i] = i0->data.f[i] * i1->data.f[i]; + } + return kTfLiteOk; + }; + + CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0, + nullptr, ®), + kTfLiteOk); +} + +void EagerModelTest::AddTfOp(TfOpType op, const std::vector& inputs, + const std::vector& outputs) { + auto attr = [](const string& key, const string& value) { + return " attr{ key: '" + key + "' value {" + value + "}}"; + }; + + if (op == kUnpack) { + string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") + + attr("axis", "i: 0"); + AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs); + } else if (op == kIdentity) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs); + } else if (op == kAdd) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerAdd", "Add", attributes, inputs, outputs); + } else if (op == kMul) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerMul", "Mul", attributes, inputs, outputs); + } else if (op == kNonExistent) { + AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); + } else if (op == kIncompatibleNodeDef) { + // "Cast" op is created without attributes - making it incompatible. + AddTfOp("EagerCast", "Cast", "", inputs, outputs); + } +} + +void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, + const std::vector& inputs, + const std::vector& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_CUSTOM; + reg.custom_name = tflite_name; + + tensorflow::NodeDef nodedef; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + nodedef_str + " op: '" + tf_name + "'", &nodedef)); + string serialized_nodedef; + CHECK(nodedef.SerializeToString(&serialized_nodedef)); + flexbuffers::Builder fbb; + fbb.Vector([&]() { + fbb.String(nodedef.op()); + fbb.String(serialized_nodedef); + }); + fbb.Finish(); + + flexbuffers_.push_back(fbb.GetBuffer()); + auto& buffer = flexbuffers_.back(); + CHECK_EQ(interpreter_->AddNodeWithParameters( + inputs, outputs, reinterpret_cast(buffer.data()), + buffer.size(), nullptr, ®), + kTfLiteOk); +} + +} // namespace testing +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0eab9e1135f02b4f22a4b36a85cf6771fbbb81d5 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/test_util.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace eager { +namespace testing { + +enum TfOpType { + kUnpack, + kIdentity, + kAdd, + kMul, + // Represents an op that does not exist in TensorFlow. + kNonExistent, + // Represents an valid TensorFlow op where the NodeDef is incompatible. + kIncompatibleNodeDef, +}; + +// This class creates models with TF and TFLite ops. In order to use this class +// to test the Eager delegate, implement a function that calls +// interpreter->ModifyGraphWithDelegate. +class EagerModelTest : public ::testing::Test { + public: + EagerModelTest() {} + ~EagerModelTest() {} + + bool Invoke(); + + // Sets the tensor's values at the given index. + void SetValues(int tensor_index, const std::vector& values); + + // Returns the tensor's values at the given index. + std::vector GetValues(int tensor_index); + + // Sets the tensor's shape at the given index. + void SetShape(int tensor_index, const std::vector& values); + + // Returns the tensor's shape at the given index. + std::vector GetShape(int tensor_index); + + const TestErrorReporter& error_reporter() const { return error_reporter_; } + + // Adds `num_tensor` tensors to the model. `inputs` contains the indices of + // the input tensors and `outputs` contains the indices of the output + // tensors. All tensors are set to have `type` and `dims`. + void AddTensors(int num_tensors, const std::vector& inputs, + const std::vector& outputs, const TfLiteType& type, + const std::vector& dims); + + // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors + // and `outputs` contains the indices of the output tensors. + void AddTfLiteMulOp(const std::vector& inputs, + const std::vector& outputs); + + // Adds a TensorFlow op. `inputs` contains the indices of the + // input tensors and `outputs` contains the indices of the output tensors. + // This function is limited to the set of ops defined in TfOpType. + void AddTfOp(TfOpType op, const std::vector& inputs, + const std::vector& outputs); + + protected: + std::unique_ptr interpreter_; + TestErrorReporter error_reporter_; + + private: + // Helper method to add a TensorFlow op. tflite_names needs to start with + // "Eager" in order to work with the Eager delegate. + void AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, const std::vector& inputs, + const std::vector& outputs); + + std::vector> flexbuffers_; +}; + +} // namespace testing +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc index 4426c653e6ff80aac52b50e06a3005173490433d..c8aa0b7f69f8f6bd3bff52b13f3cc7d689a514da 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.cc +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -13,10 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/eager/constants.h" namespace tflite { namespace eager { +bool IsEagerOp(const char* custom_name) { + return custom_name && strncmp(custom_name, kCustomCodePrefix, + strlen(kCustomCodePrefix)) == 0; +} + TfLiteStatus ConvertStatus(TfLiteContext* context, const tensorflow::Status& status) { if (!status.ok()) { diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h index a9407be071192e9b7f25f95df9e76a5f44e7c9e3..b7363361bec47f30e0741e3a76a5a375d7d9aeb1 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -23,6 +23,10 @@ limitations under the License. namespace tflite { namespace eager { +// Checks whether the prefix of the custom name indicates the operation is an +// Eager operation. +bool IsEagerOp(const char* custom_name); + // Converts a tensorflow:Status into a TfLiteStatus. If the original status // represented an error, reports it using the given 'context'. TfLiteStatus ConvertStatus(TfLiteContext* context, diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc index c4fbf5412776a2c5743e8d72fc6729cfd709c545..4e92da8d34f4f6c9b9c1ecd959cfaed25051f826 100644 --- a/tensorflow/contrib/lite/delegates/eager/util_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc @@ -102,6 +102,16 @@ TEST(UtilTest, TypeConversions) { EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); } +TEST(UtilTest, IsEagerOp) { + EXPECT_TRUE(IsEagerOp("Eager")); + EXPECT_TRUE(IsEagerOp("EagerOp")); + EXPECT_FALSE(IsEagerOp("eager")); + EXPECT_FALSE(IsEagerOp("Eage")); + EXPECT_FALSE(IsEagerOp("OpEager")); + EXPECT_FALSE(IsEagerOp(nullptr)); + EXPECT_FALSE(IsEagerOp("")); +} + } // namespace } // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD index 091f8fbce734b466de33bb4b84e5e0fc3e4a71ef..954955f24b87f79a8dbe2863f608d532e25902c6 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/BUILD +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -22,7 +22,10 @@ tf_cc_test( name = "nnapi_delegate_test", size = "small", srcs = ["nnapi_delegate_test.cc"], - tags = ["no_oss"], + tags = [ + "no_oss", + "noasan", # TODO(b/112326936): re-enable for asan once fixed. + ], deps = [ ":nnapi_delegate", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index 0c7f6d3125b9cafd15987f132a5f417e4c2f0c6c..b1b8e9890c99b0e9fdbf735bf3f04795e8577203 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -142,6 +142,12 @@ class NNAPIOpBuilder { ANEURALNETWORKS_TENSOR_INT32); } + TfLiteStatus AddVectorFloat32Operand(const float* values, + uint32_t num_values) { + return AddVectorOperand(values, num_values, + ANEURALNETWORKS_TENSOR_FLOAT32); + } + TfLiteStatus AddPoolingParams(void* data) { auto builtin = reinterpret_cast(data); AddScalarInt32Operand(builtin->padding); @@ -167,6 +173,37 @@ class NNAPIOpBuilder { return kTfLiteOk; } + TfLiteStatus AddAdditionalFloat32OutputTensor(uint32_t dimension_count) { + std::vector dims(dimension_count, 0); + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_FLOAT32, + .dimensionCount = dimension_count, + .dimensions = dims.data()}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + augmented_outputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddStateFloat32Tensor(int tensor_index, + int* ann_tensor_index_out) { + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + int ann_index = operand_mapping_->add_new_non_tensor_operand(); + + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), tensor->params.scale, + tensor->params.zero_point}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + augmented_inputs_.push_back(ann_index); + + *ann_tensor_index_out = ann_index; + return kTfLiteOk; + } + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. // This returns the NN API tensor index corresponding to the created tensor. // If another caller previously created a NN API tensor for `tensor_index` @@ -198,6 +235,10 @@ class NNAPIOpBuilder { nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; scale = tensor->params.scale; zeroPoint = tensor->params.zero_point; + if (scale == 0) { + // TENSOR_QUANT8_ASYMM with zero scale is not valid in NNAPI. + scale = 1; + } break; case kTfLiteInt32: nn_type = ANEURALNETWORKS_TENSOR_INT32; @@ -290,9 +331,10 @@ class NNAPIDelegateKernel { public: NNAPIDelegateKernel() = default; - typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, - NNAPIOpBuilder* builder, - TfLiteNode* node); + typedef ANeuralNetworksOperationType (*MappingFn)( + TfLiteContext*, NNAPIOpBuilder* builder, TfLiteNode* node, + std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs); // Return a function that knows how to translate a node into its operands // when called. You can use this function to see if a node is supported @@ -303,7 +345,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinAdd: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarInt32Operand(builtin->activation); @@ -316,7 +360,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinMul: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarInt32Operand(builtin->activation); @@ -329,7 +375,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinAveragePool2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { builder->AddPoolingParams(node->builtin_data); return ANEURALNETWORKS_AVERAGE_POOL_2D; }; @@ -340,7 +388,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinMaxPool2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { builder->AddPoolingParams(node->builtin_data); return ANEURALNETWORKS_MAX_POOL_2D; }; @@ -351,7 +401,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinL2Pool2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { builder->AddPoolingParams(node->builtin_data); return ANEURALNETWORKS_L2_POOL_2D; }; @@ -369,7 +421,9 @@ class NNAPIDelegateKernel { return nullptr; } return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarInt32Operand(builtin->padding); @@ -385,7 +439,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinDepthwiseConv2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( node->builtin_data); builder->AddScalarInt32Operand(builtin->padding); @@ -402,7 +458,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinFullyConnected: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( node->builtin_data); builder->AddScalarInt32Operand(builtin->activation); @@ -415,7 +473,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinSoftmax: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarFloat32Operand(builtin->beta); @@ -428,7 +488,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinReshape: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { return ANEURALNETWORKS_RESHAPE; }; } else { @@ -436,10 +498,11 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinSqueeze: - // Squeeze requires NNAPI1.1. if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); // Note that we add the squeeze dimensions even if the dimensions @@ -460,13 +523,251 @@ class NNAPIDelegateKernel { return nullptr; } return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { return ANEURALNETWORKS_L2_NORMALIZATION; }; } + case kTfLiteBuiltinLocalResponseNormalization: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->radius); + builder->AddScalarFloat32Operand(builtin->bias); + builder->AddScalarFloat32Operand(builtin->alpha); + builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION; + }; + } else { + // TODO(miaowang): clean-up code and return early in the unsupported + // case. + return nullptr; + } + break; + case kTfLiteBuiltinLshProjection: + if (version == 1) { + // NNAPI does not support sparse projection correctly (b/111751836). + if (reinterpret_cast(node->builtin_data) + ->type == kTfLiteLshProjectionSparse) { + return nullptr; + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->type); + return ANEURALNETWORKS_LSH_PROJECTION; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConcatenation: + if (version == 1 && + reinterpret_cast(node->builtin_data) + ->activation == kTfLiteActNone) { + if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8) { + // NNAPI only support concatenating quantized tensor of the same + // scale and offset. + auto first_param = context->tensors[node->inputs->data[0]].params; + for (int i = 0; i < node->inputs->size; i++) { + auto curr_param = context->tensors[node->inputs->data[i]].params; + if (curr_param.scale != first_param.scale || + curr_param.zero_point != first_param.zero_point) { + return nullptr; + } + } + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->axis); + return ANEURALNETWORKS_CONCATENATION; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDequantize: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_DEQUANTIZE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFloor: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_FLOOR; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRelu: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReluN1To1: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU1; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRelu6: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU6; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinLogistic: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_LOGISTIC; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinTanh: + // TODO(miaowang): add additional checks for the parameters. + if (version == 1 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float tanh. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_TANH; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSub: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float sub. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_SUB; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDiv: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float div. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DIV; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinPad: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + node->inputs->size == 2 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI does not support specifying the padding value. + // NNAPI pads physical zero for quantized tensors, so only delegate + // float pad to NNAPI. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_PAD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSpaceToBatchNd: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_SPACE_TO_BATCH_ND; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinStridedSlice: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->begin_mask); + builder->AddScalarInt32Operand(builtin->end_mask); + builder->AddScalarInt32Operand(builtin->shrink_axis_mask); + return ANEURALNETWORKS_STRIDED_SLICE; + }; + } else { + return nullptr; + } + break; case kTfLiteBuiltinTranspose: - // Transpose requires NNAPI1.1. Also note that the permutation input - // tensor value dictates the output dimensions. + // Note that the permutation input tensor value dictates the output + // dimensions. // TODO(b/110888333): Support dynamically-sized tensors in delegates. if ((version == 1) && (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) && @@ -474,13 +775,155 @@ class NNAPIDelegateKernel { (context->tensors[node->inputs->data[1]].allocation_type == kTfLiteMmapRo)) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { return ANEURALNETWORKS_TRANSPOSE; }; } else { return nullptr; } break; + case kTfLiteBuiltinRnn: + // NNAPI only support float32 weights. + // TODO(miaowang): check the number of inputs before accessing it. + if (version == 1 && + context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type == + kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out. + int ann_index; + builder->AddStateFloat32Tensor( + node->outputs->data[/*kHiddenStateTensor*/ 0], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kHiddenStateTensor*/ 0]); + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_RNN; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSvdf: + // NNAPI only support float32 weights. + if (version == 1 && + context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]] + .type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out. + int ann_index; + builder->AddStateFloat32Tensor( + node->outputs->data[/*kStateTensor*/ 0], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kStateTensor*/ 0]); + + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->rank); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_SVDF; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinLstm: + // NNAPI only support float32 weights. + // TODO(miaowang): add loggings to indicate why the op is rejected. + if (version == 1 && node->inputs->size == 18 && + context->tensors[node->inputs + ->data[/*kInputToOutputWeightsTensor*/ 4]] + .type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out for cell_state and + // output_state. + int ann_index; + builder->AddStateFloat32Tensor( + node->outputs->data[/*kOutputStateTensor*/ 0], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kOutputStateTensor*/ 0]); + builder->AddStateFloat32Tensor( + node->outputs->data[/*kCellStateTensor*/ 1], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kCellStateTensor*/ 1]); + + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + builder->AddScalarFloat32Operand(builtin->cell_clip); + builder->AddScalarFloat32Operand(builtin->proj_clip); + + // Current NNAPI implementation requires the sratch_buffer as + // output. + builder->AddAdditionalFloat32OutputTensor(2); + return ANEURALNETWORKS_LSTM; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMean: + // NNAPI does not support generating a scalar as output for MEAN. + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 && + context->tensors[node->outputs->data[0]].dims->size > 0) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + int32_t keep_dims = 0; + if (builtin->keep_dims) keep_dims = 1; + builder->AddScalarInt32Operand(keep_dims); + return ANEURALNETWORKS_MEAN; + }; + } else { + return nullptr; + } + case kTfLiteBuiltinEmbeddingLookup: + // NNAPI only support float32 values. + if (version == 1 && + context->tensors[node->inputs->data[1]].type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_EMBEDDING_LOOKUP; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinHashtableLookup: + // NNAPI only support float32 output. + if (version == 1 && + context->tensors[node->outputs->data[0]].type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_HASHTABLE_LOOKUP; + }; + } else { + return nullptr; + } + break; default: return nullptr; } @@ -520,7 +963,12 @@ class NNAPIDelegateKernel { // Set the input tensor buffers. Note: we access tflite tensors using // absolute indices but NN api indices inputs by relative indices. int relative_input_index = 0; + int num_optional_tensors = 0; for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + if (absolute_input_index == kOptionalTensor) { + num_optional_tensors++; + continue; + } TfLiteTensor* tensor = &context->tensors[absolute_input_index]; // TODO(miaowang): make sure the delegation works with dequantized weights // as intermediate tensors. @@ -541,6 +989,20 @@ class NNAPIDelegateKernel { tensor->data.raw, tensor->bytes)); relative_output_index++; } + + // The state_out of previous invocation need to be mapped to state_in of + // current invocation. + for (size_t i = 0; i < model_state_tfl_outputs_.size(); i++) { + int state_tensor_idx = model_state_tfl_outputs_[i]; + TfLiteTensor* tensor = &context->tensors[state_tensor_idx]; + // Here we are using a deep copy for state_in tensors so that we are not + // reading and writing into the same buffer during a invocation. + // TODO(110369471): using double shared buffer to minimize the copies. + CHECK_NN(context, + ANeuralNetworksExecution_setInput( + execution, i + node->inputs->size - num_optional_tensors, + nullptr, tensor->data.raw, tensor->bytes)); + } // Invoke ANN in blocking fashion. ANeuralNetworksEvent* event = nullptr; CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); @@ -562,6 +1024,9 @@ class NNAPIDelegateKernel { // Track indices we use OperandMapping operand_mapping_; + std::vector model_state_inputs_; + std::vector model_state_tfl_outputs_; + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { // The operand builder allows creating a single op. We create it at this // reduced power position rather than in the for loop to avoid reallocating @@ -576,11 +1041,22 @@ class NNAPIDelegateKernel { context->GetNodeAndRegistration(context, node_index, &node, ®); // Map inputs to NN API tensor indices. for (auto input_index : TfLiteIntArrayView(node->inputs)) { - TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + if (input_index == kOptionalTensor && + (reg->builtin_code == kTfLiteBuiltinLstm || + reg->builtin_code == kTfLiteBuiltinSvdf)) { + // properly handle the optional tensor for LSTM and SVDF. + // currently only support float32. + // TODO(miaowang): make sure this is also able to handle quantized + // tensor when supported by NNAPI. + TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0)); + } else { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } } // Get op type and operands int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( - context, &builder, node); + context, &builder, node, &model_state_inputs_, + &model_state_tfl_outputs_); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); @@ -604,12 +1080,20 @@ class NNAPIDelegateKernel { // Make the TensorFlow lite inputs and outputs to ann_indices. for (int i : TfLiteIntArrayView(input_tensors)) { // Constant tensors are not NNAPI inputs. - if (context->tensors[i].allocation_type != kTfLiteMmapRo) { + if (i != kOptionalTensor && + context->tensors[i].allocation_type != kTfLiteMmapRo) { inputs.push_back(operand_mapping_.lite_index_to_ann(i)); } } - for (int i : TfLiteIntArrayView(output_tensors)) + // Add state input tensors as model inputs + for (int i : model_state_inputs_) { + inputs.push_back(i); + } + + for (int i : TfLiteIntArrayView(output_tensors)) { outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + } + // Tell ANN to declare inputs/outputs CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_.get(), inputs.size(), inputs.data(), diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index baf8046f9b16832ee24f7689f02563e381c601de..3224b23a0c3bc8456bd75f2923d16f0eed7d53ff 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -82,7 +82,7 @@ TEST(NNAPIDelegate, AddWithRelu) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); } -class FloatMulOpModel : public SingleOpModel { +class FloatMulOpModel : public SingleOpModelWithNNAPI { public: FloatMulOpModel(const TensorData& input1, const TensorData& input2, const TensorData& output, @@ -184,7 +184,7 @@ TEST(NNAPIDelegate, L2PoolWithNoActivation) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); } -class BaseConvolutionOpModel : public SingleOpModel { +class BaseConvolutionOpModel : public SingleOpModelWithNNAPI { public: BaseConvolutionOpModel( const TensorData& input, const TensorData& filter, @@ -713,6 +713,2808 @@ TEST(NNAPIDelegate, TransposeSimpleTest) { 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); } +class FloatSubOpModel : public SingleOpModelWithNNAPI { + public: + FloatSubOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, SubWithNoActivation) { + FloatSubOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-2.1, 0.0, 0.4, 0.3}))); +} + +class FloatDivOpModel : public SingleOpModelWithNNAPI { + public: + FloatDivOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, DivWithNoActivation) { + FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.8, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.4, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-20, 1, 2, 4}))); +} + +class BaseConcatenationOpModel : public SingleOpModelWithNNAPI { + public: + BaseConcatenationOpModel() {} + BaseConcatenationOpModel(const TensorData& input_template, int axis, + int num_inputs) { + std::vector> all_input_shapes; + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + + protected: + int output_; +}; + +class ConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(NNAPIDelegate, ConcatenationFourInputs) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2, + /*num_inputs=*/4); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + })); +} + +class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + QuantizedConcatenationOpModel(const std::vector& input_template, + int axis, int num_inputs, + const TensorData& output_template) { + std::vector> all_input_shapes; + CHECK_EQ(input_template.size(), num_inputs); + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template[i].shape); + AddInput(input_template[i]); + } + output_ = AddOutput({output_template.type, /*shape=*/{}, + output_template.min, output_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + void SetInput(int index, std::initializer_list data) { + QuantizeAndPopulate(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) { + QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) { + QuantizedConcatenationOpModel m0({{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8}, + {TensorType_UINT8, {2, 1, 2}, 0, 12.8}, + {TensorType_UINT8, {2, 1, 2}, -11, 11.8}, + {TensorType_UINT8, {2, 1, 2}, 0, 7.4}}, + /*axis=*/2, /*num_inputs=*/4, + {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +class DequantizeOpModel : public SingleOpModelWithNNAPI { + public: + DequantizeOpModel(std::initializer_list shape, float min, float max) { + input_ = AddInput({TensorType_UINT8, shape, min, max}); + output_ = AddOutput({TensorType_FLOAT32, shape}); + SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions, + CreateDequantizeOptions(builder_).Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, DequantizeFourDimensional) { + DequantizeOpModel m({2, 5}, -63.5, 64); + + m.SetInput({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}))); +} + +class FloorOpModel : public SingleOpModelWithNNAPI { + public: + FloorOpModel(std::initializer_list input_shape, TensorType input_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0); + BuildInterpreter({ + input_shape, + }); + } + + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, FloorSingleDim) { + FloorOpModel model({2}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), {8.5, 0.0}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(NNAPIDelegate, FloorMultiDims) { + FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), { + 0.0001, + 8.0001, + 0.9999, + 9.9999, + 0.5, + -0.0001, + -8.0001, + -0.9999, + -9.9999, + -0.5, + }); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5})); +} + +class LocalResponseNormOpModel : public SingleOpModelWithNNAPI { + public: + LocalResponseNormOpModel(std::initializer_list input_shape, int radius, + float bias, float alpha, float beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOptions_LocalResponseNormalizationOptions, + CreateLocalResponseNormalizationOptions(builder_, radius, bias, + alpha, beta) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/1.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 2. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}))); +} + +TEST(NNAPIDelegate, LocalResponseNormWithAlpha) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 3. + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}))); +} + +TEST(NNAPIDelegate, LocalResponseNormWithBias) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 5. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}))); +} + +TEST(NNAPIDelegate, LocalResponseNormSmallRadius) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266}))); +} + +class LSHProjectionOpModel : public SingleOpModelWithNNAPI { + public: + LSHProjectionOpModel(LSHProjectionType type, + std::initializer_list hash_shape, + std::initializer_list input_shape, + std::initializer_list weight_shape) { + hash_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_INT32); + if (weight_shape.size() > 0) { + weight_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_LSH_PROJECTION, + BuiltinOptions_LSHProjectionOptions, + CreateLSHProjectionOptions(builder_, type).Union()); + if (weight_shape.size() > 0) { + BuildInterpreter({hash_shape, input_shape, weight_shape}); + } else { + BuildInterpreter({hash_shape, input_shape}); + } + + output_size_ = 1; + for (int i : hash_shape) { + output_size_ *= i; + if (type == LSHProjectionType_SPARSE) { + break; + } + } + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetHash(std::initializer_list data) { + PopulateTensor(hash_, data); + } + + void SetWeight(std::initializer_list f) { PopulateTensor(weight_, f); } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int hash_; + int weight_; + int output_; + + int output_size_; +}; + +TEST(NNAPIDelegate, LSHProjectionDense1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0)); +} + +TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0)); +} + +TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5}); + + m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, + 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1)); +} + +class BaseActivationsOpModel : public SingleOpModelWithNNAPI { + public: + // Most activations don't take any options, so this constructor works for + // them. + BaseActivationsOpModel(BuiltinOperator type, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + BaseActivationsOpModel(BuiltinOperator type, const TensorData& input, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +const float kQuantizedTolerance = 2 * (1. / 256); + +class QuantizedActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + template + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + template + + std::vector GetOutput() { + return ExtractVector(output_); + } + template + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } +}; + +TEST(NNAPIDelegate, Relu) { + FloatActivationsOpModel m(BuiltinOperator_RELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 10, 1, // + })); +} + +TEST(NNAPIDelegate, Relu1) { + FloatActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -1.0, 1.0, -0.1, // + })); +} + +TEST(NNAPIDelegate, Relu6) { + FloatActivationsOpModel m(BuiltinOperator_RELU6, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 6, 1, // + })); +} + +TEST(NNAPIDelegate, Tanh) { + FloatActivationsOpModel m(BuiltinOperator_TANH, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0, -0.9999877, 0.9640275, 0.999329, // + 0.99505475, -0.9640275, 1, 0.7615941, // + }))); +} + +TEST(NNAPIDelegate, LogisticFloat) { + FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }))); +} + +TEST(NNAPIDelegate, LogisticQuantized) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); +} + +#if 0 +class ResizeBilinearOpModel : public SingleOpModelWithNNAPI { + public: + ResizeBilinearOpModel(const TensorData& input, + std::initializer_list size_data = {}) { + bool const_size = size_data.size() != 0; + input_ = AddInput(input); + if (const_size) { + size_ = AddConstInput(TensorType_INT32, size_data, {2}); + } else { + size_ = AddInput({TensorType_INT32, {2}}); + } + output_ = AddOutput(input.type); + SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, + BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_).Union()); + if (const_size) { + BuildInterpreter({GetShape(input_)}); + } else { + BuildInterpreter({GetShape(input_), GetShape(size_)}); + } + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + private: + int input_; + int size_; + int output_; +}; + +TEST(NNAPIDelegate, ResizeBilinearHorizontal) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(NNAPIDelegate, ResizeBilinearVertical) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} +#endif + +template +class PadOpModel : public SingleOpModelWithNNAPI { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetQuantizedInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetQuantizedPadValue(float data) { + QuantizeAndPopulate(constant_values_, {data}); + } + + void SetPaddings(std::initializer_list paddings) { + PopulateTensor(paddings_, paddings); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + protected: + int input_; + int output_; + int paddings_; + int constant_values_; +}; + +class PadOpConstModel : public PadOpModel { + public: + PadOpConstModel(const TensorData& input, + std::initializer_list paddings_shape, + std::initializer_list paddings, + const TensorData& output) { + input_ = AddInput(input); + paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape); + output_ = AddOutput(output); + + SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_).Union()); + BuildInterpreter({input.shape}); + } +}; + +TEST(NNAPIDelegate, PadAdvancedConstTest) { + PadOpConstModel m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2}, + {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetPaddings(std::initializer_list data) { + PopulateTensor(paddings_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int block_shape_; + int paddings_; + int output_; +}; + +class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) { + SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, + 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10, + })); +} + +TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, + 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0, + 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, + })); +} + +template +class StridedSliceOpModel : public SingleOpModelWithNNAPI { + public: + StridedSliceOpModel(std::initializer_list input_shape, + std::initializer_list begin_shape, + std::initializer_list end_shape, + std::initializer_list strides_shape, int begin_mask, + int end_mask, int ellipsis_mask, int new_axis_mask, + int shrink_axis_mask) { + input_ = AddInput(tensor_input_type); + begin_ = AddInput(TensorType_INT32); + end_ = AddInput(TensorType_INT32); + strides_ = AddInput(TensorType_INT32); + output_ = AddOutput(tensor_input_type); + SetBuiltinOp( + BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, + CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, + new_axis_mask, shrink_axis_mask) + .Union()); + BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); + } + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); + } + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); + } + + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int end_; + int strides_; + int output_; +}; + +TEST(NNAPIDelegate, StridedSliceIn2D) { + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); +} + +TEST(NNAPIDelegate, StridedSliceIn2D_ShrinkAxis_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-2, -1}); + m.SetEnd({-1, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(NNAPIDelegate, StridedSliceIn2D_ShrinkAxisMask) { + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({1, 1}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0}; + +static std::initializer_list rnn_weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static std::initializer_list rnn_recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +static std::initializer_list rnn_bias = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + +class RNNOpModel : public SingleOpModelWithNNAPI { + public: + RNNOpModel(int batches, int units, int size, + const TensorType& weights = TensorType_FLOAT32, + const TensorType& recurrent_weights = TensorType_FLOAT32) + : batches_(batches), units_(units), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(weights); + recurrent_weights_ = AddInput(recurrent_weights); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, RnnBlackBoxTest) { + RNNOpModel rnn(2, 16, 8); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, + 0.35867718, 0.36897406, 0.73463392, + + 0.14278367, -1.64410412, -0.75222826, + -0.57290924, 0.12729003, 0.7567004, + + 0.49837467, 0.19278903, 0.26584083, + 0.17660543, 0.52949083, -0.77931279, + + -0.11186574, 0.13164264, -0.05349274, + -0.72674477, -0.5683046, 0.55900657, + + -0.68892461, 0.37783599, 0.18263303, + -0.63690937, 0.44483393, -0.71817774, + + -0.81299269, -0.86831826, 1.43940818, + -0.95760226, 1.82078898, 0.71135032, + + -1.45006323, -0.82251364, -1.69082689, + -1.65087092, -1.89238167, 1.54172635, + + 0.03966608, -0.24936394, -0.77526885, + 2.06740379, -1.51439476, 1.43768692, + + 0.11771342, -0.23761693, -0.65898693, + 0.31088525, -1.55601168, -0.87661445, + + -0.89477462, 1.67204106, -0.53235275, + -0.6230064, 0.29819036, 1.06939757, +}; + +static float svdf_golden_output_rank_1[] = { + 0.014899, -0.0517661, -0.143725, -0.00271883, + -0.03004015, 0.09565311, 0.1587342, 0.00784263, + + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.01582633, 0.03858774, -0.03001583, -0.02671271, + + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.01432795, 0.05524484, 0.1101355, -0.02382665, + + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.02333033, 0.02293761, 0.12338032, 0.04326871, + + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.01064911, -0.17503069, 0.07821996, -0.00224009, + + 0.0886511, -0.0875401, -0.269283, 0.0281379, + -0.02282338, 0.09741908, 0.32973239, 0.12281385, + + -0.201174, -0.586145, -0.628624, -0.0330412, + 0.24780814, -0.39304617, -0.22473189, 0.02589256, + + -0.0839096, -0.299329, 0.108746, 0.109808, + 0.10084175, -0.06416984, 0.28936723, 0.0026358, + + 0.419114, -0.237824, -0.422627, 0.175115, + -0.2314795, -0.18584411, -0.4228974, -0.12928449, + + 0.36726, -0.522303, -0.456502, -0.175475, + 0.17012937, -0.34447709, 0.38505614, -0.28158101, +}; + +static float svdf_golden_output_rank_2[] = { + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, + + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, + + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, + + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, + + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, + + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, + + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, + + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, + + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, + + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, +}; + +class BaseSVDFOpModel : public SingleOpModelWithNNAPI { + public: + BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank, + TensorType weights_feature_type = TensorType_FLOAT32, + TensorType weights_time_type = TensorType_FLOAT32) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + input_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(weights_feature_type); + weights_time_ = AddInput(weights_time_type); + bias_ = AddNullInput(); + state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + BuildInterpreter({ + {batches_, input_size_}, // Input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_} // bias tensor + }); + } + + // Populates the weights_feature tensor. + void SetWeightsFeature(std::initializer_list f) { + PopulateTensor(weights_feature_, f); + } + + // Populates the weights_time tensor. + void SetWeightsTime(std::initializer_list f) { + PopulateTensor(weights_time_, f); + } + + // Populates the input tensor. + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + // Extracts the output tensor from the SVDF op. + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_feature_; + int weights_time_; + int bias_; + int state_; + int output_; + + int batches_; + int units_; + int input_size_; + int memory_size_; + int rank_; +}; + +class SVDFOpModel : public BaseSVDFOpModel { + public: + using BaseSVDFOpModel::BaseSVDFOpModel; + + void VerifyGoldens(float golden_input[], float golden_output[], + int golden_size, float tolerance = 1e-5) { + const int svdf_num_batches = num_batches(); + const int svdf_input_size = input_size(); + const int svdf_num_units = num_units(); + const int input_sequence_size = + golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF + // op and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = + golden_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + SetInput(0, batch_start, batch_end); + + Invoke(); + + const float* golden_start = + golden_output + i * svdf_num_units * svdf_num_batches; + const float* golden_end = + golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST(NNAPIDelegate, SVDFBlackBoxTestRank1) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input)); +} + +TEST(NNAPIDelegate, SVDFBlackBoxTestRank2) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input)); +} + +class LSTMOpModel : public SingleOpModelWithNNAPI { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(weight_type); + } + + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(weight_type); + } + + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(weight_type); + } + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(weight_type); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + int input_activation_state_; + int input_cell_state_; + + int output_; + int output_state_; + int cell_state_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; + + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; + + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; + } +}; + +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + } +}; + +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class BaseReduceOpModel : public SingleOpModelWithNNAPI { + public: + void SetAxis(const std::vector& data) { PopulateTensor(axis_, data); } + + template + void SetInput(std::vector data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + int Input() { return input_; } + + protected: + int input_; + int axis_; + int output_; +}; + +// Model for the tests case where axis is a const tensor. +class MeanOpConstModel : public BaseReduceOpModel { + public: + MeanOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Tests for reduce_mean +TEST(NNAPIDelegate, MeanFloatNotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(NNAPIDelegate, MeanFloatKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI { + public: + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { + input_ = AddInput(TensorType_INT32); + weight_ = AddInput(weight_type); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({index_shape, weight_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } +}; + +TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) { + EmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }))); +} + +class HashtableLookupOpModel : public SingleOpModelWithNNAPI { + public: + HashtableLookupOpModel(std::initializer_list lookup_shape, + std::initializer_list key_shape, + std::initializer_list value_shape, + TensorType type) { + lookup_ = AddInput(TensorType_INT32); + key_ = AddInput(TensorType_INT32); + value_ = AddInput(type); + output_ = AddOutput(type); + hit_ = AddOutput(TensorType_UINT8); + SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({lookup_shape, key_shape, value_shape}); + } + + void SetLookup(std::initializer_list data) { + PopulateTensor(lookup_, data); + } + + void SetHashtableKey(std::initializer_list data) { + PopulateTensor(key_, data); + } + + void SetHashtableValue(const std::vector& content) { + PopulateStringTensor(value_, content); + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + for (int i = 0; i < rows; i++) { + tensor->data.f[i] = function(i); + } + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int features = tensor->dims->data[1]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < features; j++) { + tensor->data.f[i * features + j] = function(i, j); + } + } + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetHit() { return ExtractVector(hit_); } + + private: + int lookup_; + int key_; + int value_; + int output_; + int hit_; +}; + +TEST(NNAPIDelegate, HashtableLookupTest2DInput) { + HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-nd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +TEST(NNAPIDelegate, HashtableLookupTest1DInput) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i) { return i * i / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.4, // 2-nd item + 0, // Not found + 0.0, // 0-th item + 0.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index d74e275f0439b1ce56b29e0eadff5f211f6a4faa..30fee64a6f621016446eff58c305e88fda01fa76 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -315,7 +315,7 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const labelLayers = [[NSMutableArray alloc] init]; oldPredictionValues = [[NSMutableDictionary alloc] init]; - NSString* graph_path = FilePathForResourceName(model_file_name, @"tflite"); + NSString* graph_path = FilePathForResourceName(model_file_name, model_file_type); model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); if (!model) { LOG(FATAL) << "Failed to mmap model " << graph_path; diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile index c7d3b1c966eaa0de71f5c37a6a77b3881e30ddd7..cd8c39043f6df61ed83e75e80a42156fdba68642 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/Podfile +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_camera_example' - pod 'TensorFlowLite' + pod 'TensorFlowLite', '0.1.7' diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile index e4aca2be82d437a0225d2c15d3e486b0344aa978..c885398f44456bc1b7429b4f6605237bbc64e654 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/Podfile +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_simple_example' - pod 'TensorFlowLite' + pod 'TensorFlowLite', '0.1.7' diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index e36218e4f12057a362af47c48454f7930fc495f2..6fdcf78b69c6799fc2e666af1150efb88b55ff5c 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -16,11 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ #define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -28,8 +24,6 @@ limitations under the License. #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/version.h" -#include "tensorflow/contrib/lite/examples/label_image/label_image.h" - namespace tflite { namespace label_image { diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 86d7d1cc4a625243791d5e7d5b746526a58efb6d..7c6f523041ad5a516f348c1b4f66683128838228 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -213,22 +213,23 @@ void RunInference(Settings* s) { } } - const int output_size = 1000; - const size_t num_results = 5; const float threshold = 0.001f; std::vector> top_results; int output = interpreter->outputs()[0]; + TfLiteIntArray* output_dims = interpreter->tensor(output)->dims; + // assume output dims to be something like (1, 1, ... ,size) + auto output_size = output_dims->data[output_dims->size - 1]; switch (interpreter->tensor(output)->type) { case kTfLiteFloat32: get_top_n(interpreter->typed_output_tensor(0), output_size, - num_results, threshold, &top_results, true); + s->number_of_results, threshold, &top_results, true); break; case kTfLiteUInt8: get_top_n(interpreter->typed_output_tensor(0), - output_size, num_results, threshold, &top_results, - false); + output_size, s->number_of_results, threshold, + &top_results, false); break; default: LOG(FATAL) << "cannot handle output type " @@ -259,6 +260,7 @@ void display_usage() { << "--labels, -l: labels for the model\n" << "--tflite_model, -m: model_name.tflite\n" << "--profiling, -p: [0|1], profiling or not\n" + << "--num_results, -r: number of results to show\n" << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "\n"; @@ -280,12 +282,13 @@ int Main(int argc, char** argv) { {"threads", required_argument, nullptr, 't'}, {"input_mean", required_argument, nullptr, 'b'}, {"input_std", required_argument, nullptr, 's'}, + {"num_results", required_argument, nullptr, 'r'}, {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. */ int option_index = 0; - c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:s:t:v:", long_options, + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:r:s:t:v:", long_options, &option_index); /* Detect the end of the options. */ @@ -315,6 +318,10 @@ int Main(int argc, char** argv) { s.profiling = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; + case 'r': + s.number_of_results = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) + break; case 's': s.input_std = strtod(optarg, nullptr); break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h index 4b48014e1c77eca1eca081f0fe906441a5dcce22..34c223f713b9fe7692440a6b7538f00be995ad11 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -34,6 +34,7 @@ struct Settings { string labels_file_name = "./labels.txt"; string input_layer_type = "uint8_t"; int number_of_threads = 4; + int number_of_results = 5; }; } // namespace label_image diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD index b09bb9ea1056b90a00729677d84ae52eca2b0220..50f8da66d06abaf0637866e85c04e80fee042071 100644 --- a/tensorflow/contrib/lite/experimental/c/BUILD +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -5,6 +5,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow/contrib/lite:build_def.bzl", "tflite_cc_shared_object", + "tflite_copts", "tflite_jni_binary", ) @@ -30,16 +31,11 @@ tflite_cc_shared_object( ], ) -tflite_jni_binary( - name = "libtensorflowlite_c_jni.so", - linkscript = ":version_script.lds", - deps = [":c_api"], -) - cc_library( name = "c_api", srcs = ["c_api.cc"], hdrs = ["c_api.h"], + copts = tflite_copts(), deps = [ "//tensorflow/contrib/lite:context", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc index add4c6813dcc6c015e6248a2aaa1addf2a266ec3..9d29e8b3e055e86a9e68285d81de742e36452215 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -27,6 +27,8 @@ struct _TFL_Interpreter { std::unique_ptr impl; }; +// LINT.IfChange + TFL_Interpreter* TFL_NewInterpreter(const void* model_data, int32_t model_size) { auto model = tflite::FlatBufferModel::BuildFromBuffer( @@ -113,6 +115,8 @@ TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data, return kTfLiteOk; } +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs) + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c72a5cae9ebfb15f60961fe25e622663cad89a41 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore @@ -0,0 +1,13 @@ +# Unity generated +Builds/ +Temp/ +Library/ +obj/ +# Visual Studio / MonoDevelop generated +*.csproj +*.unityproj +*.sln +*.suo +*.userprefs +# OS generated +.DS_Store diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta new file mode 100644 index 0000000000000000000000000000000000000000..ed9337b53e880b62f70953f197613dcb1409d208 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 71d1b4219b1da4aeaa1cebbec324fc81 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta new file mode 100644 index 0000000000000000000000000000000000000000..edcce00939a298683b15ea45a5ec92709c6abc4f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d948aead14abd4c88947c9886d16f774 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta new file mode 100644 index 0000000000000000000000000000000000000000..36b35516f0cee064c8d8e4814a2ae515e28590ce --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: b810b85b794fa48fd93100acf5525e1f +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta new file mode 100644 index 0000000000000000000000000000000000000000..d4133da49a88d38a57d074d28b903f9f18102413 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 154f4201e2e454d4696fa5834eaa3ad3 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity new file mode 100644 index 0000000000000000000000000000000000000000..bcf24b89e335781877a7046001ac4deb6fc55041 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity @@ -0,0 +1,477 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!29 &1 +OcclusionCullingSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_OcclusionBakeSettings: + smallestOccluder: 5 + smallestHole: 0.25 + backfaceThreshold: 100 + m_SceneGUID: 00000000000000000000000000000000 + m_OcclusionCullingData: {fileID: 0} +--- !u!104 &2 +RenderSettings: + m_ObjectHideFlags: 0 + serializedVersion: 8 + m_Fog: 0 + m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1} + m_FogMode: 3 + m_FogDensity: 0.01 + m_LinearFogStart: 0 + m_LinearFogEnd: 300 + m_AmbientSkyColor: {r: 0.212, g: 0.227, b: 0.259, a: 1} + m_AmbientEquatorColor: {r: 0.114, g: 0.125, b: 0.133, a: 1} + m_AmbientGroundColor: {r: 0.047, g: 0.043, b: 0.035, a: 1} + m_AmbientIntensity: 1 + m_AmbientMode: 3 + m_SubtractiveShadowColor: {r: 0.42, g: 0.478, b: 0.627, a: 1} + m_SkyboxMaterial: {fileID: 0} + m_HaloStrength: 0.5 + m_FlareStrength: 1 + m_FlareFadeSpeed: 3 + m_HaloTexture: {fileID: 0} + m_SpotCookie: {fileID: 10001, guid: 0000000000000000e000000000000000, type: 0} + m_DefaultReflectionMode: 0 + m_DefaultReflectionResolution: 128 + m_ReflectionBounces: 1 + m_ReflectionIntensity: 1 + m_CustomReflection: {fileID: 0} + m_Sun: {fileID: 0} + m_IndirectSpecularColor: {r: 0, g: 0, b: 0, a: 1} +--- !u!157 &3 +LightmapSettings: + m_ObjectHideFlags: 0 + serializedVersion: 11 + m_GIWorkflowMode: 1 + m_GISettings: + serializedVersion: 2 + m_BounceScale: 1 + m_IndirectOutputScale: 1 + m_AlbedoBoost: 1 + m_TemporalCoherenceThreshold: 1 + m_EnvironmentLightingMode: 0 + m_EnableBakedLightmaps: 0 + m_EnableRealtimeLightmaps: 0 + m_LightmapEditorSettings: + serializedVersion: 9 + m_Resolution: 2 + m_BakeResolution: 40 + m_TextureWidth: 1024 + m_TextureHeight: 1024 + m_AO: 0 + m_AOMaxDistance: 1 + m_CompAOExponent: 1 + m_CompAOExponentDirect: 0 + m_Padding: 2 + m_LightmapParameters: {fileID: 0} + m_LightmapsBakeMode: 1 + m_TextureCompression: 1 + m_FinalGather: 0 + m_FinalGatherFiltering: 1 + m_FinalGatherRayCount: 256 + m_ReflectionCompression: 2 + m_MixedBakeMode: 2 + m_BakeBackend: 0 + m_PVRSampling: 1 + m_PVRDirectSampleCount: 32 + m_PVRSampleCount: 500 + m_PVRBounces: 2 + m_PVRFilterTypeDirect: 0 + m_PVRFilterTypeIndirect: 0 + m_PVRFilterTypeAO: 0 + m_PVRFilteringMode: 1 + m_PVRCulling: 1 + m_PVRFilteringGaussRadiusDirect: 1 + m_PVRFilteringGaussRadiusIndirect: 5 + m_PVRFilteringGaussRadiusAO: 2 + m_PVRFilteringAtrousPositionSigmaDirect: 0.5 + m_PVRFilteringAtrousPositionSigmaIndirect: 2 + m_PVRFilteringAtrousPositionSigmaAO: 1 + m_ShowResolutionOverlay: 1 + m_LightingDataAsset: {fileID: 0} + m_UseShadowmask: 1 +--- !u!196 &4 +NavMeshSettings: + serializedVersion: 2 + m_ObjectHideFlags: 0 + m_BuildSettings: + serializedVersion: 2 + agentTypeID: 0 + agentRadius: 0.5 + agentHeight: 2 + agentSlope: 45 + agentClimb: 0.4 + ledgeDropHeight: 0 + maxJumpAcrossDistance: 0 + minRegionArea: 2 + manualCellSize: 0 + cellSize: 0.16666667 + manualTileSize: 0 + tileSize: 256 + accuratePlacement: 0 + debug: + m_Flags: 0 + m_NavMeshData: {fileID: 0} +--- !u!1 &492081941 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 492081945} + - component: {fileID: 492081944} + - component: {fileID: 492081943} + - component: {fileID: 492081942} + m_Layer: 0 + m_Name: Main Camera + m_TagString: MainCamera + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!81 &492081942 +AudioListener: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 +--- !u!124 &492081943 +Behaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 +--- !u!20 &492081944 +Camera: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 + serializedVersion: 2 + m_ClearFlags: 1 + m_BackGroundColor: {r: 0.21933319, g: 0.21933319, b: 0.21933319, a: 0} + m_NormalizedViewPortRect: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + near clip plane: 0.3 + far clip plane: 1000 + field of view: 60 + orthographic: 1 + orthographic size: 5 + m_Depth: -1 + m_CullingMask: + serializedVersion: 2 + m_Bits: 4294967295 + m_RenderingPath: -1 + m_TargetTexture: {fileID: 0} + m_TargetDisplay: 0 + m_TargetEye: 3 + m_HDR: 1 + m_AllowMSAA: 1 + m_AllowDynamicResolution: 0 + m_ForceIntoRT: 0 + m_OcclusionCulling: 1 + m_StereoConvergence: 10 + m_StereoSeparation: 0.022 +--- !u!4 &492081945 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: -10} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: + - {fileID: 904015944} + m_Father: {fileID: 0} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} +--- !u!1 &871349752 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 871349756} + - component: {fileID: 871349755} + - component: {fileID: 871349754} + - component: {fileID: 871349753} + m_Layer: 5 + m_Name: Canvas + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!114 &871349753 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1301386320, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_IgnoreReversedGraphics: 1 + m_BlockingObjects: 0 + m_BlockingMask: + serializedVersion: 2 + m_Bits: 4294967295 +--- !u!114 &871349754 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1980459831, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_UiScaleMode: 0 + m_ReferencePixelsPerUnit: 100 + m_ScaleFactor: 1 + m_ReferenceResolution: {x: 800, y: 600} + m_ScreenMatchMode: 0 + m_MatchWidthOrHeight: 0 + m_PhysicalUnit: 3 + m_FallbackScreenDPI: 96 + m_DefaultSpriteDPI: 96 + m_DynamicPixelsPerUnit: 1 +--- !u!223 &871349755 +Canvas: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + serializedVersion: 3 + m_RenderMode: 0 + m_Camera: {fileID: 0} + m_PlaneDistance: 100 + m_PixelPerfect: 0 + m_ReceivesEvents: 1 + m_OverrideSorting: 0 + m_OverridePixelPerfect: 0 + m_SortingBucketNormalizedSize: 0 + m_AdditionalShaderChannelsFlag: 0 + m_SortingLayerID: 0 + m_SortingOrder: 0 + m_TargetDisplay: 0 +--- !u!224 &871349756 +RectTransform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 0, y: 0, z: 0} + m_Children: + - {fileID: 1726294324} + m_Father: {fileID: 0} + m_RootOrder: 1 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0, y: 0} + m_AnchorMax: {x: 0, y: 0} + m_AnchoredPosition: {x: 0, y: 0} + m_SizeDelta: {x: 0, y: 0} + m_Pivot: {x: 0, y: 0} +--- !u!1 &904015943 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 904015944} + - component: {fileID: 904015945} + m_Layer: 0 + m_Name: HelloTFLite + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!4 &904015944 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 904015943} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 492081945} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} +--- !u!114 &904015945 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 904015943} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 899510441e0ca4be0879d3055e467878, type: 3} + m_Name: + m_EditorClassIdentifier: + model: {fileID: 4900000, guid: adff4e1dbdba344c199ee4fe7e84457e, type: 3} + inputs: + - 1 + - 3 + - 7 + inferenceText: {fileID: 1726294325} +--- !u!1 &1726294323 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 1726294324} + - component: {fileID: 1726294326} + - component: {fileID: 1726294325} + m_Layer: 5 + m_Name: InferenceText + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!224 &1726294324 +RectTransform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} + m_LocalRotation: {x: -0, y: -0, z: -0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 871349756} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0.5, y: 0.5} + m_AnchorMax: {x: 0.5, y: 0.5} + m_AnchoredPosition: {x: 0, y: 25} + m_SizeDelta: {x: 450, y: 250} + m_Pivot: {x: 0.5, y: 0.5} +--- !u!114 &1726294325 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 708705254, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_Material: {fileID: 0} + m_Color: {r: 0.9338235, g: 0.9338235, b: 0.9338235, a: 1} + m_RaycastTarget: 1 + m_OnCullStateChanged: + m_PersistentCalls: + m_Calls: [] + m_TypeName: UnityEngine.UI.MaskableGraphic+CullStateChangedEvent, UnityEngine.UI, + Version=1.0.0.0, Culture=neutral, PublicKeyToken=null + m_FontData: + m_Font: {fileID: 10102, guid: 0000000000000000e000000000000000, type: 0} + m_FontSize: 35 + m_FontStyle: 0 + m_BestFit: 0 + m_MinSize: 2 + m_MaxSize: 40 + m_Alignment: 4 + m_AlignByGeometry: 0 + m_RichText: 1 + m_HorizontalOverflow: 0 + m_VerticalOverflow: 0 + m_LineSpacing: 1 + m_Text: 'Inference took 0.0153 ms + + Input: 1,3,7 + + Output: 3,9,21' +--- !u!222 &1726294326 +CanvasRenderer: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} +--- !u!1 &2026426602 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 2026426605} + - component: {fileID: 2026426604} + - component: {fileID: 2026426603} + m_Layer: 0 + m_Name: EventSystem + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!114 &2026426603 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1077351063, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_HorizontalAxis: Horizontal + m_VerticalAxis: Vertical + m_SubmitButton: Submit + m_CancelButton: Cancel + m_InputActionsPerSecond: 10 + m_RepeatDelay: 0.5 + m_ForceModuleActive: 0 +--- !u!114 &2026426604 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: -619905303, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_FirstSelected: {fileID: 0} + m_sendNavigationEvents: 1 + m_DragThreshold: 5 +--- !u!4 &2026426605 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 0} + m_RootOrder: 2 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta new file mode 100644 index 0000000000000000000000000000000000000000..e1e13efb66027b555f1d45c76fe58fe2103774a2 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: f8a8c37a396584bb7b21687f33d6d3f8 +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes differ diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta new file mode 100644 index 0000000000000000000000000000000000000000..ba24871413e06154afd0c0d5e2db83b7619d34a9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: adff4e1dbdba344c199ee4fe7e84457e +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta new file mode 100644 index 0000000000000000000000000000000000000000..28fde68b8b1346e88375dc7a8613270f0e2f2762 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: f7d1e2dec09b64acdb7b8f5aef9fcb44 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs new file mode 100644 index 0000000000000000000000000000000000000000..83291e61794819e7c57f69ed2be6ea40294e01da --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using TensorFlowLite; +using UnityEngine; +using UnityEngine.UI; + +/// +/// Simple example demonstrating use of the experimental C# bindings for TensorFlowLite. +/// +public class HelloTFLite : MonoBehaviour { + + [Tooltip("Configurable TFLite model.")] + public TextAsset model; + + [Tooltip("Configurable TFLite input tensor data.")] + public float[] inputs; + + [Tooltip("Target Text widget for display of inference execution.")] + public Text inferenceText; + + private Interpreter interpreter; + private float[] outputs; + + void Awake() { + // As the demo is extremely simple, there's no need to run at full frame-rate. + QualitySettings.vSyncCount = 0; + Application.targetFrameRate = 5; + } + + void Start () { + interpreter = new Interpreter(model.bytes); + Debug.LogFormat( + "InputCount: {0}, OutputCount: {1}", + interpreter.GetInputTensorCount(), + interpreter.GetOutputTensorCount()); + } + + void Update () { + if (inputs == null) { + return; + } + + if (outputs == null || outputs.Length != inputs.Length) { + interpreter.ResizeInputTensor(0, new int[]{inputs.Length}); + interpreter.AllocateTensors(); + outputs = new float[inputs.Length]; + } + + float startTimeSeconds = Time.realtimeSinceStartup; + interpreter.SetInputTensorData(0, inputs); + interpreter.Invoke(); + interpreter.GetOutputTensorData(0, outputs); + float inferenceTimeSeconds = Time.realtimeSinceStartup - startTimeSeconds; + + inferenceText.text = string.Format( + "Inference took {0:0.0000} ms\nInput(s): {1}\nOutput(s): {2}", + inferenceTimeSeconds * 1000.0, + ArrayToString(inputs), + ArrayToString(outputs)); + } + + void OnDestroy() { + interpreter.Dispose(); + } + + private static string ArrayToString(float[] values) { + return string.Join(",", values.Select(x => x.ToString()).ToArray()); + } +} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta new file mode 100644 index 0000000000000000000000000000000000000000..ba83f45084bb624e5e7777684b0fda98b4d46688 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 899510441e0ca4be0879d3055e467878 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta new file mode 100644 index 0000000000000000000000000000000000000000..bf5ce15c6a6932398d798d193b54f4ecfd8ba2d8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 16dad1655bcdc48f7b325a2a634b9c69 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta new file mode 100644 index 0000000000000000000000000000000000000000..22ed2c466bde1668595967f7a07f34a9193aaec8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d70863368f8904d509a9b73d3a555914 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs new file mode 100644 index 0000000000000000000000000000000000000000..ab966bae2efb9431e2f9f35dc818d130aabd71f6 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +using System; +using System.Runtime.InteropServices; + +using TFL_Interpreter = System.IntPtr; +using TFL_Tensor = System.IntPtr; + +namespace TensorFlowLite +{ + /// + /// Simple C# bindings for the experimental TensorFlowLite C API. + /// + public class Interpreter : IDisposable + { + private const string TensorFlowLibrary = "tensorflowlite_c"; + + private TFL_Interpreter handle; + + public Interpreter(byte[] modelData) { + GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned); + IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject(); + handle = TFL_NewInterpreter(modelDataPtr, modelData.Length); + if (handle == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter"); + } + + ~Interpreter() { + Dispose(); + } + + public void Dispose() { + if (handle != IntPtr.Zero) TFL_DeleteInterpreter(handle); + handle = IntPtr.Zero; + } + + public void Invoke() { + ThrowIfError(TFL_InterpreterInvoke(handle)); + } + + public int GetInputTensorCount() { + return TFL_InterpreterGetInputTensorCount(handle); + } + + public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) { + GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned); + IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject(); + TFL_Tensor tensor = TFL_InterpreterGetInputTensor(handle, inputTensorIndex); + ThrowIfError(TFL_TensorCopyFromBuffer( + tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData))); + } + + public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) { + ThrowIfError(TFL_InterpreterResizeInputTensor( + handle, inputTensorIndex, inputTensorShape, inputTensorShape.Length)); + } + + public void AllocateTensors() { + ThrowIfError(TFL_InterpreterAllocateTensors(handle)); + } + + public int GetOutputTensorCount() { + return TFL_InterpreterGetOutputTensorCount(handle); + } + + public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) { + GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned); + IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject(); + TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(handle, outputTensorIndex); + ThrowIfError(TFL_TensorCopyToBuffer( + tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData))); + } + + private static void ThrowIfError(int resultCode) { + if (resultCode != 0) throw new Exception("TensorFlowLite operation failed."); + } + + #region Externs + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Interpreter TFL_NewInterpreter( + IntPtr model_data, + int model_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe void TFL_DeleteInterpreter(TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterGetInputTensorCount( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Tensor TFL_InterpreterGetInputTensor( + TFL_Interpreter interpreter, + int input_index); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterResizeInputTensor( + TFL_Interpreter interpreter, + int input_index, + int[] input_dims, + int input_dims_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterAllocateTensors( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterInvoke(TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterGetOutputTensorCount( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Tensor TFL_InterpreterGetOutputTensor( + TFL_Interpreter interpreter, + int output_index); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_TensorCopyFromBuffer( + TFL_Tensor tensor, + IntPtr input_data, + int input_data_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_TensorCopyToBuffer( + TFL_Tensor tensor, + IntPtr output_data, + int output_data_size); + + #endregion + } +} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta new file mode 100644 index 0000000000000000000000000000000000000000..5ec84ef7f70e9be45ff6292ed7a412fac35010de --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 0bbaf59e6ac914ed1b28174fb9008a09 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..da6112576a5ca4290108f6d4c731bd4c391e91d4 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset @@ -0,0 +1,17 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!11 &1 +AudioManager: + m_ObjectHideFlags: 0 + m_Volume: 1 + Rolloff Scale: 1 + Doppler Factor: 1 + Default Speaker Mode: 2 + m_SampleRate: 0 + m_DSPBufferSize: 0 + m_VirtualVoiceCount: 512 + m_RealVoiceCount: 32 + m_SpatializerPlugin: + m_AmbisonicDecoderPlugin: + m_DisableAudio: 0 + m_VirtualizeEffects: 1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..e7886b266a005f4d9d80f2fef8d1649dcfd3ed2b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset @@ -0,0 +1,6 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!236 &1 +ClusterInputManager: + m_ObjectHideFlags: 0 + m_Inputs: [] diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..78992f08c7ab7a4353c8a7d07cf1548174aaacbf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset @@ -0,0 +1,29 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!55 &1 +PhysicsManager: + m_ObjectHideFlags: 0 + serializedVersion: 7 + m_Gravity: {x: 0, y: -9.81, z: 0} + m_DefaultMaterial: {fileID: 0} + m_BounceThreshold: 2 + m_SleepThreshold: 0.005 + m_DefaultContactOffset: 0.01 + m_DefaultSolverIterations: 6 + m_DefaultSolverVelocityIterations: 1 + m_QueriesHitBackfaces: 0 + m_QueriesHitTriggers: 1 + m_EnableAdaptiveForce: 0 + m_ClothInterCollisionDistance: 0 + m_ClothInterCollisionStiffness: 0 + m_ContactsGeneration: 1 + m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff + m_AutoSimulation: 1 + m_AutoSyncTransforms: 1 + m_ClothInterCollisionSettingsToggle: 0 + m_ContactPairsMode: 0 + m_BroadphaseType: 0 + m_WorldBounds: + m_Center: {x: 0, y: 0, z: 0} + m_Extent: {x: 250, y: 250, z: 250} + m_WorldSubdivisions: 8 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..6dc24f7dfdb697ad6f5d0a4ec5599bcd3cbd2f43 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset @@ -0,0 +1,7 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!1045 &1 +EditorBuildSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_Scenes: [] diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..fcd016402f97e4c009a16640517a6930ed615ef9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset @@ -0,0 +1,21 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!159 &1 +EditorSettings: + m_ObjectHideFlags: 0 + serializedVersion: 7 + m_ExternalVersionControlSupport: Visible Meta Files + m_SerializationMode: 2 + m_LineEndingsForNewScripts: 1 + m_DefaultBehaviorMode: 1 + m_SpritePackerMode: 4 + m_SpritePackerPaddingPower: 1 + m_EtcTextureCompressorBehavior: 1 + m_EtcTextureFastCompressor: 1 + m_EtcTextureNormalCompressor: 2 + m_EtcTextureBestCompressor: 4 + m_ProjectGenerationIncludedExtensions: txt;xml;fnt;cd;asmdef;rsp + m_ProjectGenerationRootNamespace: + m_UserGeneratedProjectSuffix: + m_CollabEditorSettings: + inProgressEnabled: 1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..a9bbfb02d1e7065b7d0e90609a3928d667933477 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset @@ -0,0 +1,64 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!30 &1 +GraphicsSettings: + m_ObjectHideFlags: 0 + serializedVersion: 12 + m_Deferred: + m_Mode: 1 + m_Shader: {fileID: 69, guid: 0000000000000000f000000000000000, type: 0} + m_DeferredReflections: + m_Mode: 1 + m_Shader: {fileID: 74, guid: 0000000000000000f000000000000000, type: 0} + m_ScreenSpaceShadows: + m_Mode: 1 + m_Shader: {fileID: 64, guid: 0000000000000000f000000000000000, type: 0} + m_LegacyDeferred: + m_Mode: 1 + m_Shader: {fileID: 63, guid: 0000000000000000f000000000000000, type: 0} + m_DepthNormals: + m_Mode: 1 + m_Shader: {fileID: 62, guid: 0000000000000000f000000000000000, type: 0} + m_MotionVectors: + m_Mode: 1 + m_Shader: {fileID: 75, guid: 0000000000000000f000000000000000, type: 0} + m_LightHalo: + m_Mode: 1 + m_Shader: {fileID: 105, guid: 0000000000000000f000000000000000, type: 0} + m_LensFlare: + m_Mode: 1 + m_Shader: {fileID: 102, guid: 0000000000000000f000000000000000, type: 0} + m_AlwaysIncludedShaders: + - {fileID: 7, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15104, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15105, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15106, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 10753, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 10770, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 17000, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 16000, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 16002, guid: 0000000000000000f000000000000000, type: 0} + m_PreloadedShaders: [] + m_SpritesDefaultMaterial: {fileID: 10754, guid: 0000000000000000f000000000000000, + type: 0} + m_CustomRenderPipeline: {fileID: 0} + m_TransparencySortMode: 0 + m_TransparencySortAxis: {x: 0, y: 0, z: 1} + m_DefaultRenderingPath: 1 + m_DefaultMobileRenderingPath: 1 + m_TierSettings: [] + m_LightmapStripping: 0 + m_FogStripping: 0 + m_InstancingStripping: 0 + m_LightmapKeepPlain: 1 + m_LightmapKeepDirCombined: 1 + m_LightmapKeepDynamicPlain: 1 + m_LightmapKeepDynamicDirCombined: 1 + m_LightmapKeepShadowMask: 1 + m_LightmapKeepSubtractive: 1 + m_FogKeepLinear: 1 + m_FogKeepExp: 1 + m_FogKeepExp2: 1 + m_AlbedoSwatchInfos: [] + m_LightsUseLinearIntensity: 0 + m_LightsUseColorTemperature: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..17c8f538e2152c0a0310b4870979eeecece2153c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset @@ -0,0 +1,295 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!13 &1 +InputManager: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_Axes: + - serializedVersion: 3 + m_Name: Horizontal + descriptiveName: + descriptiveNegativeName: + negativeButton: left + positiveButton: right + altNegativeButton: a + altPositiveButton: d + gravity: 3 + dead: 0.001 + sensitivity: 3 + snap: 1 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Vertical + descriptiveName: + descriptiveNegativeName: + negativeButton: down + positiveButton: up + altNegativeButton: s + altPositiveButton: w + gravity: 3 + dead: 0.001 + sensitivity: 3 + snap: 1 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire1 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left ctrl + altNegativeButton: + altPositiveButton: mouse 0 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire2 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left alt + altNegativeButton: + altPositiveButton: mouse 1 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire3 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left shift + altNegativeButton: + altPositiveButton: mouse 2 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Jump + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: space + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse X + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse Y + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 1 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse ScrollWheel + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 2 + joyNum: 0 + - serializedVersion: 3 + m_Name: Horizontal + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0.19 + sensitivity: 1 + snap: 0 + invert: 0 + type: 2 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Vertical + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0.19 + sensitivity: 1 + snap: 0 + invert: 1 + type: 2 + axis: 1 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire1 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 0 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire2 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 1 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire3 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 2 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Jump + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 3 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Submit + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: return + altNegativeButton: + altPositiveButton: joystick button 0 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Submit + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: enter + altNegativeButton: + altPositiveButton: space + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Cancel + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: escape + altNegativeButton: + altPositiveButton: joystick button 1 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset new file mode 100644 index 0000000000000000000000000000000000000000..3b0b7c3d183abdd300112f56965916ef11667f54 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset @@ -0,0 +1,91 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!126 &1 +NavMeshProjectSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + areas: + - name: Walkable + cost: 1 + - name: Not Walkable + cost: 1 + - name: Jump + cost: 2 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + m_LastAgentTypeID: -887442657 + m_Settings: + - serializedVersion: 2 + agentTypeID: 0 + agentRadius: 0.5 + agentHeight: 2 + agentSlope: 45 + agentClimb: 0.75 + ledgeDropHeight: 0 + maxJumpAcrossDistance: 0 + minRegionArea: 2 + manualCellSize: 0 + cellSize: 0.16666667 + manualTileSize: 0 + tileSize: 256 + accuratePlacement: 0 + debug: + m_Flags: 0 + m_SettingNames: + - Humanoid diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..5dc6a831d9f2a11f08ed96571e0f602e3c3908b5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset @@ -0,0 +1,8 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!149 &1 +NetworkManager: + m_ObjectHideFlags: 0 + m_DebugLevel: 0 + m_Sendrate: 15 + m_AssetToPrefab: {} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..132ee6bc868f1aae138555dc139e054b0d1d8620 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset @@ -0,0 +1,37 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!19 &1 +Physics2DSettings: + m_ObjectHideFlags: 0 + serializedVersion: 3 + m_Gravity: {x: 0, y: -9.81} + m_DefaultMaterial: {fileID: 0} + m_VelocityIterations: 8 + m_PositionIterations: 3 + m_VelocityThreshold: 1 + m_MaxLinearCorrection: 0.2 + m_MaxAngularCorrection: 8 + m_MaxTranslationSpeed: 100 + m_MaxRotationSpeed: 360 + m_BaumgarteScale: 0.2 + m_BaumgarteTimeOfImpactScale: 0.75 + m_TimeToSleep: 0.5 + m_LinearSleepTolerance: 0.01 + m_AngularSleepTolerance: 2 + m_DefaultContactOffset: 0.01 + m_AutoSimulation: 1 + m_QueriesHitTriggers: 1 + m_QueriesStartInColliders: 1 + m_ChangeStopsCallbacks: 0 + m_CallbacksOnDisable: 1 + m_AutoSyncTransforms: 1 + m_AlwaysShowColliders: 0 + m_ShowColliderSleep: 1 + m_ShowColliderContacts: 0 + m_ShowColliderAABB: 0 + m_ContactArrowScale: 0.2 + m_ColliderAwakeColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.7529412} + m_ColliderAsleepColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.36078432} + m_ColliderContactColor: {r: 1, g: 0, b: 1, a: 0.6862745} + m_ColliderAABBColor: {r: 1, g: 1, b: 0, a: 0.2509804} + m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..3fbfab76c13c84f66a166c5dfe1d4552503350ff --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset @@ -0,0 +1,641 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!129 &1 +PlayerSettings: + m_ObjectHideFlags: 0 + serializedVersion: 14 + productGUID: a084943b991dd4597b140f4ce2b41c65 + AndroidProfiler: 0 + AndroidFilterTouchesWhenObscured: 0 + defaultScreenOrientation: 4 + targetDevice: 2 + useOnDemandResources: 0 + accelerometerFrequency: 60 + companyName: DefaultCompany + productName: TensorFlowLitePlugin + defaultCursor: {fileID: 0} + cursorHotspot: {x: 0, y: 0} + m_SplashScreenBackgroundColor: {r: 0.13725491, g: 0.12156863, b: 0.1254902, a: 1} + m_ShowUnitySplashScreen: 1 + m_ShowUnitySplashLogo: 1 + m_SplashScreenOverlayOpacity: 1 + m_SplashScreenAnimation: 1 + m_SplashScreenLogoStyle: 1 + m_SplashScreenDrawMode: 0 + m_SplashScreenBackgroundAnimationZoom: 1 + m_SplashScreenLogoAnimationZoom: 1 + m_SplashScreenBackgroundLandscapeAspect: 1 + m_SplashScreenBackgroundPortraitAspect: 1 + m_SplashScreenBackgroundLandscapeUvs: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + m_SplashScreenBackgroundPortraitUvs: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + m_SplashScreenLogos: [] + m_VirtualRealitySplashScreen: {fileID: 0} + m_HolographicTrackingLossScreen: {fileID: 0} + defaultScreenWidth: 1024 + defaultScreenHeight: 768 + defaultScreenWidthWeb: 960 + defaultScreenHeightWeb: 600 + m_StereoRenderingPath: 0 + m_ActiveColorSpace: 0 + m_MTRendering: 1 + m_StackTraceTypes: 010000000100000001000000010000000100000001000000 + iosShowActivityIndicatorOnLoading: -1 + androidShowActivityIndicatorOnLoading: -1 + tizenShowActivityIndicatorOnLoading: -1 + iosAppInBackgroundBehavior: 0 + displayResolutionDialog: 1 + iosAllowHTTPDownload: 1 + allowedAutorotateToPortrait: 1 + allowedAutorotateToPortraitUpsideDown: 1 + allowedAutorotateToLandscapeRight: 1 + allowedAutorotateToLandscapeLeft: 1 + useOSAutorotation: 1 + use32BitDisplayBuffer: 1 + preserveFramebufferAlpha: 0 + disableDepthAndStencilBuffers: 0 + androidBlitType: 0 + defaultIsFullScreen: 1 + defaultIsNativeResolution: 1 + macRetinaSupport: 1 + runInBackground: 0 + captureSingleScreen: 0 + muteOtherAudioSources: 0 + Prepare IOS For Recording: 0 + Force IOS Speakers When Recording: 0 + deferSystemGesturesMode: 0 + hideHomeButton: 0 + submitAnalytics: 1 + usePlayerLog: 1 + bakeCollisionMeshes: 0 + forceSingleInstance: 0 + resizableWindow: 0 + useMacAppStoreValidation: 0 + macAppStoreCategory: public.app-category.games + gpuSkinning: 0 + graphicsJobs: 0 + xboxPIXTextureCapture: 0 + xboxEnableAvatar: 0 + xboxEnableKinect: 0 + xboxEnableKinectAutoTracking: 0 + xboxEnableFitness: 0 + visibleInBackground: 1 + allowFullscreenSwitch: 1 + graphicsJobMode: 0 + macFullscreenMode: 2 + d3d11FullscreenMode: 1 + xboxSpeechDB: 0 + xboxEnableHeadOrientation: 0 + xboxEnableGuest: 0 + xboxEnablePIXSampling: 0 + metalFramebufferOnly: 0 + n3dsDisableStereoscopicView: 0 + n3dsEnableSharedListOpt: 1 + n3dsEnableVSync: 0 + xboxOneResolution: 0 + xboxOneSResolution: 0 + xboxOneXResolution: 3 + xboxOneMonoLoggingLevel: 0 + xboxOneLoggingLevel: 1 + xboxOneDisableEsram: 0 + xboxOnePresentImmediateThreshold: 0 + videoMemoryForVertexBuffers: 0 + psp2PowerMode: 0 + psp2AcquireBGM: 1 + wiiUTVResolution: 0 + wiiUGamePadMSAA: 1 + wiiUSupportsNunchuk: 0 + wiiUSupportsClassicController: 0 + wiiUSupportsBalanceBoard: 0 + wiiUSupportsMotionPlus: 0 + wiiUSupportsProController: 0 + wiiUAllowScreenCapture: 1 + wiiUControllerCount: 0 + m_SupportedAspectRatios: + 4:3: 1 + 5:4: 1 + 16:10: 1 + 16:9: 1 + Others: 1 + bundleVersion: 1.0 + preloadedAssets: [] + metroInputSource: 0 + wsaTransparentSwapchain: 0 + m_HolographicPauseOnTrackingLoss: 1 + xboxOneDisableKinectGpuReservation: 0 + xboxOneEnable7thCore: 0 + vrSettings: + cardboard: + depthFormat: 0 + enableTransitionView: 0 + daydream: + depthFormat: 0 + useSustainedPerformanceMode: 0 + enableVideoLayer: 0 + useProtectedVideoMemory: 0 + minimumSupportedHeadTracking: 0 + maximumSupportedHeadTracking: 1 + hololens: + depthFormat: 1 + depthBufferSharingEnabled: 0 + oculus: + sharedDepthBuffer: 0 + dashSupport: 0 + protectGraphicsMemory: 0 + useHDRDisplay: 0 + m_ColorGamuts: 00000000 + targetPixelDensity: 30 + resolutionScalingMode: 0 + androidSupportedAspectRatio: 1 + androidMaxAspectRatio: 2.1 + applicationIdentifier: {} + buildNumber: {} + AndroidBundleVersionCode: 1 + AndroidMinSdkVersion: 16 + AndroidTargetSdkVersion: 0 + AndroidPreferredInstallLocation: 1 + aotOptions: + stripEngineCode: 1 + iPhoneStrippingLevel: 0 + iPhoneScriptCallOptimization: 0 + ForceInternetPermission: 0 + ForceSDCardPermission: 0 + CreateWallpaper: 0 + APKExpansionFiles: 0 + keepLoadedShadersAlive: 0 + StripUnusedMeshComponents: 0 + VertexChannelCompressionMask: + serializedVersion: 2 + m_Bits: 238 + iPhoneSdkVersion: 988 + iOSTargetOSVersionString: 7.0 + tvOSSdkVersion: 0 + tvOSRequireExtendedGameController: 0 + tvOSTargetOSVersionString: 9.0 + uIPrerenderedIcon: 0 + uIRequiresPersistentWiFi: 0 + uIRequiresFullScreen: 1 + uIStatusBarHidden: 1 + uIExitOnSuspend: 0 + uIStatusBarStyle: 0 + iPhoneSplashScreen: {fileID: 0} + iPhoneHighResSplashScreen: {fileID: 0} + iPhoneTallHighResSplashScreen: {fileID: 0} + iPhone47inSplashScreen: {fileID: 0} + iPhone55inPortraitSplashScreen: {fileID: 0} + iPhone55inLandscapeSplashScreen: {fileID: 0} + iPhone58inPortraitSplashScreen: {fileID: 0} + iPhone58inLandscapeSplashScreen: {fileID: 0} + iPadPortraitSplashScreen: {fileID: 0} + iPadHighResPortraitSplashScreen: {fileID: 0} + iPadLandscapeSplashScreen: {fileID: 0} + iPadHighResLandscapeSplashScreen: {fileID: 0} + appleTVSplashScreen: {fileID: 0} + appleTVSplashScreen2x: {fileID: 0} + tvOSSmallIconLayers: [] + tvOSSmallIconLayers2x: [] + tvOSLargeIconLayers: [] + tvOSTopShelfImageLayers: [] + tvOSTopShelfImageLayers2x: [] + tvOSTopShelfImageWideLayers: [] + tvOSTopShelfImageWideLayers2x: [] + iOSLaunchScreenType: 0 + iOSLaunchScreenPortrait: {fileID: 0} + iOSLaunchScreenLandscape: {fileID: 0} + iOSLaunchScreenBackgroundColor: + serializedVersion: 2 + rgba: 0 + iOSLaunchScreenFillPct: 100 + iOSLaunchScreenSize: 100 + iOSLaunchScreenCustomXibPath: + iOSLaunchScreeniPadType: 0 + iOSLaunchScreeniPadImage: {fileID: 0} + iOSLaunchScreeniPadBackgroundColor: + serializedVersion: 2 + rgba: 0 + iOSLaunchScreeniPadFillPct: 100 + iOSLaunchScreeniPadSize: 100 + iOSLaunchScreeniPadCustomXibPath: + iOSUseLaunchScreenStoryboard: 0 + iOSLaunchScreenCustomStoryboardPath: + iOSDeviceRequirements: [] + iOSURLSchemes: [] + iOSBackgroundModes: 0 + iOSMetalForceHardShadows: 0 + metalEditorSupport: 1 + metalAPIValidation: 1 + iOSRenderExtraFrameOnPause: 0 + appleDeveloperTeamID: + iOSManualSigningProvisioningProfileID: + tvOSManualSigningProvisioningProfileID: + appleEnableAutomaticSigning: 0 + clonedFromGUID: 00000000000000000000000000000000 + AndroidTargetDevice: 0 + AndroidSplashScreenScale: 0 + androidSplashScreen: {fileID: 0} + AndroidKeystoreName: + AndroidKeyaliasName: + AndroidTVCompatibility: 1 + AndroidIsGame: 1 + AndroidEnableTango: 0 + androidEnableBanner: 1 + androidUseLowAccuracyLocation: 0 + m_AndroidBanners: + - width: 320 + height: 180 + banner: {fileID: 0} + androidGamepadSupportLevel: 0 + resolutionDialogBanner: {fileID: 0} + m_BuildTargetIcons: [] + m_BuildTargetBatching: [] + m_BuildTargetGraphicsAPIs: [] + m_BuildTargetVRSettings: [] + m_BuildTargetEnableVuforiaSettings: [] + openGLRequireES31: 0 + openGLRequireES31AEP: 0 + m_TemplateCustomTags: {} + mobileMTRendering: + Android: 1 + iPhone: 1 + tvOS: 1 + m_BuildTargetGroupLightmapEncodingQuality: [] + wiiUTitleID: 0005000011000000 + wiiUGroupID: 00010000 + wiiUCommonSaveSize: 4096 + wiiUAccountSaveSize: 2048 + wiiUOlvAccessKey: 0 + wiiUTinCode: 0 + wiiUJoinGameId: 0 + wiiUJoinGameModeMask: 0000000000000000 + wiiUCommonBossSize: 0 + wiiUAccountBossSize: 0 + wiiUAddOnUniqueIDs: [] + wiiUMainThreadStackSize: 3072 + wiiULoaderThreadStackSize: 1024 + wiiUSystemHeapSize: 128 + wiiUTVStartupScreen: {fileID: 0} + wiiUGamePadStartupScreen: {fileID: 0} + wiiUDrcBufferDisabled: 0 + wiiUProfilerLibPath: + playModeTestRunnerEnabled: 0 + actionOnDotNetUnhandledException: 1 + enableInternalProfiler: 0 + logObjCUncaughtExceptions: 1 + enableCrashReportAPI: 0 + cameraUsageDescription: + locationUsageDescription: + microphoneUsageDescription: + switchNetLibKey: + switchSocketMemoryPoolSize: 6144 + switchSocketAllocatorPoolSize: 128 + switchSocketConcurrencyLimit: 14 + switchScreenResolutionBehavior: 2 + switchUseCPUProfiler: 0 + switchApplicationID: 0x01004b9000490000 + switchNSODependencies: + switchTitleNames_0: + switchTitleNames_1: + switchTitleNames_2: + switchTitleNames_3: + switchTitleNames_4: + switchTitleNames_5: + switchTitleNames_6: + switchTitleNames_7: + switchTitleNames_8: + switchTitleNames_9: + switchTitleNames_10: + switchTitleNames_11: + switchTitleNames_12: + switchTitleNames_13: + switchTitleNames_14: + switchPublisherNames_0: + switchPublisherNames_1: + switchPublisherNames_2: + switchPublisherNames_3: + switchPublisherNames_4: + switchPublisherNames_5: + switchPublisherNames_6: + switchPublisherNames_7: + switchPublisherNames_8: + switchPublisherNames_9: + switchPublisherNames_10: + switchPublisherNames_11: + switchPublisherNames_12: + switchPublisherNames_13: + switchPublisherNames_14: + switchIcons_0: {fileID: 0} + switchIcons_1: {fileID: 0} + switchIcons_2: {fileID: 0} + switchIcons_3: {fileID: 0} + switchIcons_4: {fileID: 0} + switchIcons_5: {fileID: 0} + switchIcons_6: {fileID: 0} + switchIcons_7: {fileID: 0} + switchIcons_8: {fileID: 0} + switchIcons_9: {fileID: 0} + switchIcons_10: {fileID: 0} + switchIcons_11: {fileID: 0} + switchIcons_12: {fileID: 0} + switchIcons_13: {fileID: 0} + switchIcons_14: {fileID: 0} + switchSmallIcons_0: {fileID: 0} + switchSmallIcons_1: {fileID: 0} + switchSmallIcons_2: {fileID: 0} + switchSmallIcons_3: {fileID: 0} + switchSmallIcons_4: {fileID: 0} + switchSmallIcons_5: {fileID: 0} + switchSmallIcons_6: {fileID: 0} + switchSmallIcons_7: {fileID: 0} + switchSmallIcons_8: {fileID: 0} + switchSmallIcons_9: {fileID: 0} + switchSmallIcons_10: {fileID: 0} + switchSmallIcons_11: {fileID: 0} + switchSmallIcons_12: {fileID: 0} + switchSmallIcons_13: {fileID: 0} + switchSmallIcons_14: {fileID: 0} + switchManualHTML: + switchAccessibleURLs: + switchLegalInformation: + switchMainThreadStackSize: 1048576 + switchPresenceGroupId: + switchLogoHandling: 0 + switchReleaseVersion: 0 + switchDisplayVersion: 1.0.0 + switchStartupUserAccount: 0 + switchTouchScreenUsage: 0 + switchSupportedLanguagesMask: 0 + switchLogoType: 0 + switchApplicationErrorCodeCategory: + switchUserAccountSaveDataSize: 0 + switchUserAccountSaveDataJournalSize: 0 + switchApplicationAttribute: 0 + switchCardSpecSize: -1 + switchCardSpecClock: -1 + switchRatingsMask: 0 + switchRatingsInt_0: 0 + switchRatingsInt_1: 0 + switchRatingsInt_2: 0 + switchRatingsInt_3: 0 + switchRatingsInt_4: 0 + switchRatingsInt_5: 0 + switchRatingsInt_6: 0 + switchRatingsInt_7: 0 + switchRatingsInt_8: 0 + switchRatingsInt_9: 0 + switchRatingsInt_10: 0 + switchRatingsInt_11: 0 + switchLocalCommunicationIds_0: + switchLocalCommunicationIds_1: + switchLocalCommunicationIds_2: + switchLocalCommunicationIds_3: + switchLocalCommunicationIds_4: + switchLocalCommunicationIds_5: + switchLocalCommunicationIds_6: + switchLocalCommunicationIds_7: + switchParentalControl: 0 + switchAllowsScreenshot: 1 + switchAllowsVideoCapturing: 1 + switchAllowsRuntimeAddOnContentInstall: 0 + switchDataLossConfirmation: 0 + switchSupportedNpadStyles: 3 + switchSocketConfigEnabled: 0 + switchTcpInitialSendBufferSize: 32 + switchTcpInitialReceiveBufferSize: 64 + switchTcpAutoSendBufferSizeMax: 256 + switchTcpAutoReceiveBufferSizeMax: 256 + switchUdpSendBufferSize: 9 + switchUdpReceiveBufferSize: 42 + switchSocketBufferEfficiency: 4 + switchSocketInitializeEnabled: 1 + switchNetworkInterfaceManagerInitializeEnabled: 1 + switchPlayerConnectionEnabled: 1 + ps4NPAgeRating: 12 + ps4NPTitleSecret: + ps4NPTrophyPackPath: + ps4ParentalLevel: 11 + ps4ContentID: ED1633-NPXX51362_00-0000000000000000 + ps4Category: 0 + ps4MasterVersion: 01.00 + ps4AppVersion: 01.00 + ps4AppType: 0 + ps4ParamSfxPath: + ps4VideoOutPixelFormat: 0 + ps4VideoOutInitialWidth: 1920 + ps4VideoOutBaseModeInitialWidth: 1920 + ps4VideoOutReprojectionRate: 60 + ps4PronunciationXMLPath: + ps4PronunciationSIGPath: + ps4BackgroundImagePath: + ps4StartupImagePath: + ps4StartupImagesFolder: + ps4IconImagesFolder: + ps4SaveDataImagePath: + ps4SdkOverride: + ps4BGMPath: + ps4ShareFilePath: + ps4ShareOverlayImagePath: + ps4PrivacyGuardImagePath: + ps4NPtitleDatPath: + ps4RemotePlayKeyAssignment: -1 + ps4RemotePlayKeyMappingDir: + ps4PlayTogetherPlayerCount: 0 + ps4EnterButtonAssignment: 1 + ps4ApplicationParam1: 0 + ps4ApplicationParam2: 0 + ps4ApplicationParam3: 0 + ps4ApplicationParam4: 0 + ps4DownloadDataSize: 0 + ps4GarlicHeapSize: 2048 + ps4ProGarlicHeapSize: 2560 + ps4Passcode: d3hjjul8UhK6ZnQCEBYYQPozR9sQV066 + ps4pnSessions: 1 + ps4pnPresence: 1 + ps4pnFriends: 1 + ps4pnGameCustomData: 1 + playerPrefsSupport: 0 + restrictedAudioUsageRights: 0 + ps4UseResolutionFallback: 0 + ps4ReprojectionSupport: 0 + ps4UseAudio3dBackend: 0 + ps4SocialScreenEnabled: 0 + ps4ScriptOptimizationLevel: 0 + ps4Audio3dVirtualSpeakerCount: 14 + ps4attribCpuUsage: 0 + ps4PatchPkgPath: + ps4PatchLatestPkgPath: + ps4PatchChangeinfoPath: + ps4PatchDayOne: 0 + ps4attribUserManagement: 0 + ps4attribMoveSupport: 0 + ps4attrib3DSupport: 0 + ps4attribShareSupport: 0 + ps4attribExclusiveVR: 0 + ps4disableAutoHideSplash: 0 + ps4videoRecordingFeaturesUsed: 0 + ps4contentSearchFeaturesUsed: 0 + ps4attribEyeToEyeDistanceSettingVR: 0 + ps4IncludedModules: [] + monoEnv: + psp2Splashimage: {fileID: 0} + psp2NPTrophyPackPath: + psp2NPSupportGBMorGJP: 0 + psp2NPAgeRating: 12 + psp2NPTitleDatPath: + psp2NPCommsID: + psp2NPCommunicationsID: + psp2NPCommsPassphrase: + psp2NPCommsSig: + psp2ParamSfxPath: + psp2ManualPath: + psp2LiveAreaGatePath: + psp2LiveAreaBackroundPath: + psp2LiveAreaPath: + psp2LiveAreaTrialPath: + psp2PatchChangeInfoPath: + psp2PatchOriginalPackage: + psp2PackagePassword: 3onkgZsAECEn0fzCoWiCtWCKe4l74pE5 + psp2KeystoneFile: + psp2MemoryExpansionMode: 0 + psp2DRMType: 0 + psp2StorageType: 0 + psp2MediaCapacity: 0 + psp2DLCConfigPath: + psp2ThumbnailPath: + psp2BackgroundPath: + psp2SoundPath: + psp2TrophyCommId: + psp2TrophyPackagePath: + psp2PackagedResourcesPath: + psp2SaveDataQuota: 10240 + psp2ParentalLevel: 1 + psp2ShortTitle: Not Set + psp2ContentID: IV0000-ABCD12345_00-0123456789ABCDEF + psp2Category: 0 + psp2MasterVersion: 01.00 + psp2AppVersion: 01.00 + psp2TVBootMode: 0 + psp2EnterButtonAssignment: 2 + psp2TVDisableEmu: 0 + psp2AllowTwitterDialog: 1 + psp2Upgradable: 0 + psp2HealthWarning: 0 + psp2UseLibLocation: 0 + psp2InfoBarOnStartup: 0 + psp2InfoBarColor: 0 + psp2ScriptOptimizationLevel: 0 + psmSplashimage: {fileID: 0} + splashScreenBackgroundSourceLandscape: {fileID: 0} + splashScreenBackgroundSourcePortrait: {fileID: 0} + spritePackerPolicy: + webGLMemorySize: 256 + webGLExceptionSupport: 1 + webGLNameFilesAsHashes: 0 + webGLDataCaching: 0 + webGLDebugSymbols: 0 + webGLEmscriptenArgs: + webGLModulesDirectory: + webGLTemplate: APPLICATION:Default + webGLAnalyzeBuildSize: 0 + webGLUseEmbeddedResources: 0 + webGLUseWasm: 0 + webGLCompressionFormat: 1 + scriptingDefineSymbols: {} + platformArchitecture: {} + scriptingBackend: {} + incrementalIl2cppBuild: {} + additionalIl2CppArgs: + scriptingRuntimeVersion: 0 + apiCompatibilityLevelPerPlatform: {} + m_RenderingPath: 1 + m_MobileRenderingPath: 1 + metroPackageName: TensorFlowLitePlugin + metroPackageVersion: + metroCertificatePath: + metroCertificatePassword: + metroCertificateSubject: + metroCertificateIssuer: + metroCertificateNotAfter: 0000000000000000 + metroApplicationDescription: TensorFlowLitePlugin + wsaImages: {} + metroTileShortName: + metroCommandLineArgsFile: + metroTileShowName: 0 + metroMediumTileShowName: 0 + metroLargeTileShowName: 0 + metroWideTileShowName: 0 + metroDefaultTileSize: 1 + metroTileForegroundText: 2 + metroTileBackgroundColor: {r: 0.13333334, g: 0.17254902, b: 0.21568628, a: 0} + metroSplashScreenBackgroundColor: {r: 0.12941177, g: 0.17254902, b: 0.21568628, + a: 1} + metroSplashScreenUseBackgroundColor: 0 + platformCapabilities: {} + metroFTAName: + metroFTAFileTypes: [] + metroProtocolName: + metroCompilationOverrides: 1 + tizenProductDescription: + tizenProductURL: + tizenSigningProfileName: + tizenGPSPermissions: 0 + tizenMicrophonePermissions: 0 + tizenDeploymentTarget: + tizenDeploymentTargetType: -1 + tizenMinOSVersion: 1 + n3dsUseExtSaveData: 0 + n3dsCompressStaticMem: 1 + n3dsExtSaveDataNumber: 0x12345 + n3dsStackSize: 131072 + n3dsTargetPlatform: 2 + n3dsRegion: 7 + n3dsMediaSize: 0 + n3dsLogoStyle: 3 + n3dsTitle: GameName + n3dsProductCode: + n3dsApplicationId: 0xFF3FF + XboxOneProductId: + XboxOneUpdateKey: + XboxOneSandboxId: + XboxOneContentId: + XboxOneTitleId: + XboxOneSCId: + XboxOneGameOsOverridePath: + XboxOnePackagingOverridePath: + XboxOneAppManifestOverridePath: + XboxOnePackageEncryption: 0 + XboxOnePackageUpdateGranularity: 2 + XboxOneDescription: + XboxOneLanguage: + - enus + XboxOneCapability: [] + XboxOneGameRating: {} + XboxOneIsContentPackage: 0 + XboxOneEnableGPUVariability: 0 + XboxOneSockets: {} + XboxOneSplashScreen: {fileID: 0} + XboxOneAllowedProductIds: [] + XboxOnePersistentLocalStorageSize: 0 + XboxOneXTitleMemory: 8 + xboxOneScriptCompiler: 0 + vrEditorSettings: + daydream: + daydreamIconForeground: {fileID: 0} + daydreamIconBackground: {fileID: 0} + cloudServicesEnabled: {} + facebookSdkVersion: 7.9.4 + apiCompatibilityLevel: 2 + cloudProjectId: + projectName: + organizationId: + cloudEnabled: 0 + enableNativePlatformBackendsForNewInputSystem: 0 + disableOldInputManagerSupport: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a9cfb61ab55abc2f0d09b0225a802ef8122eaaf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt @@ -0,0 +1 @@ +m_EditorVersion: 2017.4.6f1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..05daac3c4922feef068af19efa921fcbb476afde --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset @@ -0,0 +1,191 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!47 &1 +QualitySettings: + m_ObjectHideFlags: 0 + serializedVersion: 5 + m_CurrentQuality: 5 + m_QualitySettings: + - serializedVersion: 2 + name: Very Low + pixelLightCount: 0 + shadows: 0 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 15 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 1 + textureQuality: 1 + anisotropicTextures: 0 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 0 + lodBias: 0.3 + maximumLODLevel: 0 + particleRaycastBudget: 4 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Low + pixelLightCount: 0 + shadows: 0 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 20 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 0 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 0 + lodBias: 0.4 + maximumLODLevel: 0 + particleRaycastBudget: 16 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Medium + pixelLightCount: 1 + shadows: 1 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 20 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 1 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 1 + lodBias: 0.7 + maximumLODLevel: 0 + particleRaycastBudget: 64 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: High + pixelLightCount: 2 + shadows: 2 + shadowResolution: 1 + shadowProjection: 1 + shadowCascades: 2 + shadowDistance: 40 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 1 + antiAliasing: 0 + softParticles: 0 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 1 + maximumLODLevel: 0 + particleRaycastBudget: 256 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Very High + pixelLightCount: 3 + shadows: 2 + shadowResolution: 2 + shadowProjection: 1 + shadowCascades: 2 + shadowDistance: 70 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 4 + textureQuality: 0 + anisotropicTextures: 2 + antiAliasing: 2 + softParticles: 1 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 1.5 + maximumLODLevel: 0 + particleRaycastBudget: 1024 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Ultra + pixelLightCount: 4 + shadows: 2 + shadowResolution: 2 + shadowProjection: 1 + shadowCascades: 4 + shadowDistance: 150 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 4 + textureQuality: 0 + anisotropicTextures: 2 + antiAliasing: 2 + softParticles: 1 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 2 + maximumLODLevel: 0 + particleRaycastBudget: 4096 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + m_PerPlatformDefaultQuality: + Android: 2 + Nintendo 3DS: 5 + Nintendo Switch: 5 + PS4: 5 + PSM: 5 + PSP2: 2 + Standalone: 5 + Tizen: 2 + WebGL: 3 + WiiU: 5 + Windows Store Apps: 5 + XboxOne: 5 + iPhone: 2 + tvOS: 2 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..1c92a7840ec11895c76785f65d949a3d20d53355 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset @@ -0,0 +1,43 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!78 &1 +TagManager: + serializedVersion: 2 + tags: [] + layers: + - Default + - TransparentFX + - Ignore Raycast + - + - Water + - UI + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + m_SortingLayers: + - name: Default + uniqueID: 0 + locked: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..558a017e1f50b2db73414a1abad3c033922774f8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset @@ -0,0 +1,9 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!5 &1 +TimeManager: + m_ObjectHideFlags: 0 + Fixed Timestep: 0.02 + Maximum Allowed Timestep: 0.33333334 + m_TimeScale: 1 + Maximum Particle Timestep: 0.03 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..3da14d5baf1fa24df1746c3ce9d969eda3a9c59d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset @@ -0,0 +1,34 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!310 &1 +UnityConnectSettings: + m_ObjectHideFlags: 0 + m_Enabled: 0 + m_TestMode: 0 + m_TestEventUrl: + m_TestConfigUrl: + m_TestInitMode: 0 + CrashReportingSettings: + m_EventUrl: https://perf-events.cloud.unity3d.com/api/events/crashes + m_NativeEventUrl: https://perf-events.cloud.unity3d.com/symbolicate + m_Enabled: 0 + m_CaptureEditorExceptions: 1 + UnityPurchasingSettings: + m_Enabled: 0 + m_TestMode: 0 + UnityAnalyticsSettings: + m_Enabled: 0 + m_InitializeOnStartup: 1 + m_TestMode: 0 + m_TestEventUrl: + m_TestConfigUrl: + UnityAdsSettings: + m_Enabled: 0 + m_InitializeOnStartup: 1 + m_TestMode: 0 + m_IosGameId: + m_AndroidGameId: + m_GameIds: {} + m_GameId: + PerformanceReportingSettings: + m_Enabled: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f480c49cd050de2192e9673f72c9e4d5c3c6ceff --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md @@ -0,0 +1,29 @@ +# TF Lite Experimental Unity Plugin + +This directory contains an experimental sample Unity (2017) Plugin, based on +the experimental TF Lite C API. The sample demonstrates running inference within +Unity by way of a C# `Interpreter` wrapper. + +Note that the native TF Lite plugin(s) *must* be built before using the Unity +Plugin, and placed in Assets/TensorFlowLite/SDK/Plugins/. For the editor (note +that this has only been tested on Linux; the syntax may differ on Mac/Windows): + +```sh +bazel build -c opt --cxxopt=--std=c++11 \ + //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so +``` + +and for Android: + +```sh +bazel build -c opt --cxxopt=--std=c++11 \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cpu=armeabi-v7a \ + //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so +``` + +If you encounter issues with native plugin discovery on Mac ("Darwin") +platforms, try renaming `libtensorflowlite_c.so` to `tensorflowlite_c.bundle`. +Similarly, on Windows you'll likely need to rename `libtensorflowlite_c.so` to +`tensorflowlite_c.dll`. diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..526aca60573f334a6b6bd536fa5be9c26d678e0f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json @@ -0,0 +1,4 @@ +{ + "dependencies": { + } +} diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9c06c4ebd958294586dbb1fde5040a0d328954ac --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/BUILD @@ -0,0 +1,84 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# ctc support classes imported directly from TensorFlow. +cc_library( + name = "ctc_utils", + hdrs = [ + "ctc_beam_entry.h", + "ctc_beam_scorer.h", + "ctc_beam_search.h", + "ctc_decoder.h", + "ctc_loss_util.h", + ], + deps = [ + ":top_n", + "//tensorflow/contrib/lite/kernels/internal:types", + "//third_party/eigen3", + ], +) + +# top_n support classes imported directly from TensorFlow. +cc_library( + name = "top_n", + hdrs = [ + "top_n.h", + ], + deps = [ + "//tensorflow/contrib/lite/kernels/internal:types", + ], +) + +cc_library( + name = "experimental_ops", + srcs = [ + "ctc_beam_search_decoder.cc", + ], + # Suppress warnings that are introduced by Eigen Tensor. + copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":ctc_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "ctc_beam_search_decoder_test", + size = "small", + srcs = ["ctc_beam_search_decoder_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":experimental_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h new file mode 100644 index 0000000000000000000000000000000000000000..a60ff2a1c53f1b3f9f490ab5cf2bc429ba09dff0 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_entry.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ + +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// The ctc_beam_search namespace holds several classes meant to be accessed only +// in case of extending the CTCBeamSearch decoder to allow custom scoring +// functions. +// +// BeamEntry is exposed through template arguments BeamScorer and BeamComparer +// of CTCBeamSearch (ctc_beam_search.h). +namespace ctc_beam_search { + +struct EmptyBeamState {}; + +struct BeamProbability { + BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {} + void Reset() { + total = kLogZero; + blank = kLogZero; + label = kLogZero; + } + float total; + float blank; + float label; +}; + +template +class BeamRoot; + +template +struct BeamEntry { + // BeamRoot::AddEntry() serves as the factory method. + friend BeamEntry* BeamRoot::AddEntry( + BeamEntry* p, int l); + inline bool Active() const { return newp.total != kLogZero; } + // Return the child at the given index, or construct a new one in-place if + // none was found. + BeamEntry& GetChild(int ind) { + auto entry = children.emplace(ind, nullptr); + auto& child_entry = entry.first->second; + // If this is a new child, populate the BeamEntry*. + if (entry.second) { + child_entry = beam_root->AddEntry(this, ind); + } + return *child_entry; + } + std::vector LabelSeq(bool merge_repeated) const { + std::vector labels; + int prev_label = -1; + const BeamEntry* c = this; + while (c->parent != nullptr) { // Checking c->parent to skip root leaf. + if (!merge_repeated || c->label != prev_label) { + labels.push_back(c->label); + } + prev_label = c->label; + c = c->parent; + } + std::reverse(labels.begin(), labels.end()); + return labels; + } + + BeamEntry* parent; + int label; + // All instances of child BeamEntry are owned by *beam_root. + std::unordered_map*> children; + BeamProbability oldp; + BeamProbability newp; + CTCBeamState state; + + private: + // Constructor giving parent, label, and the beam_root. + // The object pointed to by p cannot be copied and should not be moved, + // otherwise parent will become invalid. + // This private constructor is only called through the factory method + // BeamRoot::AddEntry(). + BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) + : parent(p), label(l), beam_root(beam_root) {} + BeamRoot* beam_root; + + BeamEntry(const BeamEntry&) = delete; + void operator=(const BeamEntry&) = delete; +}; + +// This class owns all instances of BeamEntry. This is used to avoid recursive +// destructor call during destruction. +template +class BeamRoot { + public: + BeamRoot(BeamEntry* p, int l) { root_entry_ = AddEntry(p, l); } + BeamRoot(const BeamRoot&) = delete; + BeamRoot& operator=(const BeamRoot&) = delete; + + BeamEntry* AddEntry(BeamEntry* p, int l) { + auto* new_entry = new BeamEntry(p, l, this); + beam_entries_.emplace_back(new_entry); + return new_entry; + } + BeamEntry* RootEntry() const { return root_entry_; } + + private: + BeamEntry* root_entry_ = nullptr; + std::vector>> beam_entries_; +}; + +// BeamComparer is the default beam comparer provided in CTCBeamSearch. +template +class BeamComparer { + public: + virtual ~BeamComparer() {} + virtual bool inline operator()(const BeamEntry* a, + const BeamEntry* b) const { + return a->newp.total > b->newp.total; + } +}; + +} // namespace ctc_beam_search + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h new file mode 100644 index 0000000000000000000000000000000000000000..ec60e26257b0f4126e7a7abed6a663abe277ef12 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Collection of scoring classes that can be extended and provided to the +// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a +// language model). +// +// To build a custom scorer extend and implement the pure virtual methods from +// BeamScorerInterface. The default CTC decoding behavior is implemented +// through BaseBeamScorer. + +// Copied from tensorflow/core/util/ctc/ctc_beam_scorer.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ + +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// Base implementation of a beam scorer used by default by the decoder that can +// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex +// scoring is required. Its main purpose is to provide a thin layer for +// integrating language model scoring easily. +template +class BaseBeamScorer { + public: + virtual ~BaseBeamScorer() {} + // State initialization. + virtual void InitializeState(CTCBeamState* root) const {} + // ExpandState is called when expanding a beam to one of its children. + // Called at most once per child beam. In the simplest case, no state + // expansion is done. + virtual void ExpandState(const CTCBeamState& from_state, int from_label, + CTCBeamState* to_state, int to_label) const {} + // ExpandStateEnd is called after decoding has finished. Its purpose is to + // allow a final scoring of the beam in its current state, before resorting + // and retrieving the TopN requested candidates. Called at most once per beam. + virtual void ExpandStateEnd(CTCBeamState* state) const {} + // GetStateExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandState. The score is + // multiplied (log-addition) with the input score at the current step from + // the network. + // + // The score returned should be a log-probability. In the simplest case, as + // there's no state expansion logic, the expansion score is zero. + virtual float GetStateExpansionScore(const CTCBeamState& state, + float previous_score) const { + return previous_score; + } + // GetStateEndExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandStateEnd. The score is + // multiplied (log-addition) with the final probability of the beam. + // + // The score returned should be a log-probability. + virtual float GetStateEndExpansionScore(const CTCBeamState& state) const { + return 0; + } +}; + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h new file mode 100644 index 0000000000000000000000000000000000000000..c658e43092519ba29d880a670a890af148230091 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h @@ -0,0 +1,420 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_search.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ + +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" +#include "tensorflow/contrib/lite/experimental/kernels/top_n.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +template > +class CTCBeamSearchDecoder : public CTCDecoder { + // Beam Search + // + // Example (GravesTh Fig. 7.5): + // a - + // P = [ 0.3 0.7 ] t = 0 + // [ 0.4 0.6 ] t = 1 + // + // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42 + // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58 + // + // In this case, Best Path decoding is suboptimal. + // + // For Beam Search, we use the following main recurrence relations: + // + // Relation 1: + // ---------------------------------------------------------- Eq. 1 + // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7) + // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7)) + // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and + // updated recursively in the beam entry. + // + // Relation 2: + // ---------------------------------------------------------- Eq. 2 + // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3) + // for ? in a, b, d, ..., (not including c or the blank index), + // and the recurrence starts from the beam entry for P(l=abc @ t=2). + // + // For this case, the length of the new sequence equals t+1 (t + // starts at 0). This special case can be calculated as: + // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) + // but we calculate it recursively for speed purposes. + typedef ctc_beam_search::BeamEntry BeamEntry; + typedef ctc_beam_search::BeamRoot BeamRoot; + typedef ctc_beam_search::BeamProbability BeamProbability; + + public: + typedef BaseBeamScorer DefaultBeamScorer; + + // The beam search decoder is constructed specifying the beam_width (number of + // candidates to keep at each decoding timestep) and a beam scorer (used for + // custom scoring, for example enabling the use of a language model). + // The ownership of the scorer remains with the caller. The default + // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the + // standard beam search. + CTCBeamSearchDecoder(int num_classes, int beam_width, + BaseBeamScorer* scorer, int batch_size = 1, + bool merge_repeated = false) + : CTCDecoder(num_classes, batch_size, merge_repeated), + beam_width_(beam_width), + leaves_(beam_width), + beam_scorer_(scorer) { + Reset(); + } + + ~CTCBeamSearchDecoder() override {} + + // Run the hibernating beam search algorithm on the given input. + bool Decode(const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + CTCDecoder::ScoreOutput* scores) override; + + // Calculate the next step of the beam search and update the internal state. + template + void Step(const Vector& log_input_t); + + template + float GetTopK(const int K, const Vector& input, + std::vector* top_k_logits, + std::vector* top_k_indices); + + // Retrieve the beam scorer instance used during decoding. + BaseBeamScorer* GetBeamScorer() const { return beam_scorer_; } + + // Set label selection parameters for faster decoding. + // See comments for label_selection_size_ and label_selection_margin_. + void SetLabelSelectionParameters(int label_selection_size, + float label_selection_margin) { + label_selection_size_ = label_selection_size; + label_selection_margin_ = label_selection_margin; + } + + // Reset the beam search + void Reset(); + + // Extract the top n paths at current time step + bool TopPaths(int n, std::vector>* paths, + std::vector* log_probs, bool merge_repeated) const; + + private: + int beam_width_; + + // Label selection is designed to avoid possibly very expensive scorer calls, + // by pruning the hypotheses based on the input alone. + // Label selection size controls how many items in each beam are passed + // through to the beam scorer. Only items with top N input scores are + // considered. + // Label selection margin controls the difference between minimal input score + // (versus the best scoring label) for an item to be passed to the beam + // scorer. This margin is expressed in terms of log-probability. + // Default is to do no label selection. + // For more detail: https://research.google.com/pubs/pub44823.html + int label_selection_size_ = 0; // zero means unlimited + float label_selection_margin_ = -1; // -1 means unlimited. + + gtl::TopN leaves_; + std::unique_ptr beam_root_; + BaseBeamScorer* beam_scorer_; + + CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete; + void operator=(const CTCBeamSearchDecoder&) = delete; +}; + +template +bool CTCBeamSearchDecoder::Decode( + const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, ScoreOutput* scores) { + // Storage for top paths. + std::vector> beams; + std::vector beam_log_probabilities; + int top_n = output->size(); + if (std::any_of(output->begin(), output->end(), + [this](const CTCDecoder::Output& output) -> bool { + return output.size() < this->batch_size_; + })) { + return false; + } + if (scores->rows() < batch_size_ || scores->cols() < top_n) { + return false; + } + + for (int b = 0; b < batch_size_; ++b) { + int seq_len_b = seq_len[b]; + Reset(); + + for (int t = 0; t < seq_len_b; ++t) { + // Pass log-probabilities for this example + time. + Step(input[t].row(b)); + } // for (int t... + + // O(n * log(n)) + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + for (int i = 0; i < branches->size(); ++i) { + BeamEntry* entry = (*branches)[i]; + beam_scorer_->ExpandStateEnd(&entry->state); + entry->newp.total += + beam_scorer_->GetStateEndExpansionScore(entry->state); + leaves_.push(entry); + } + + bool status = + TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_); + if (!status) { + return status; + } + + TFLITE_DCHECK_EQ(top_n, beam_log_probabilities.size()); + TFLITE_DCHECK_EQ(beams.size(), beam_log_probabilities.size()); + + for (int i = 0; i < top_n; ++i) { + // Copy output to the correct beam + batch + (*output)[i][b].swap(beams[i]); + (*scores)(b, i) = -beam_log_probabilities[i]; + } + } // for (int b... + return true; +} + +template +template +float CTCBeamSearchDecoder::GetTopK( + const int K, const Vector& input, std::vector* top_k_logits, + std::vector* top_k_indices) { + // Find Top K choices, complexity nk in worst case. The array input is read + // just once. + TFLITE_DCHECK_EQ(num_classes_, input.size()); + top_k_logits->clear(); + top_k_indices->clear(); + top_k_logits->resize(K, -INFINITY); + top_k_indices->resize(K, -1); + for (int j = 0; j < num_classes_ - 1; ++j) { + const float logit = input(j); + if (logit > (*top_k_logits)[K - 1]) { + int k = K - 1; + while (k > 0 && logit > (*top_k_logits)[k - 1]) { + (*top_k_logits)[k] = (*top_k_logits)[k - 1]; + (*top_k_indices)[k] = (*top_k_indices)[k - 1]; + k--; + } + (*top_k_logits)[k] = logit; + (*top_k_indices)[k] = j; + } + } + // Return max value which is in 0th index or blank character logit + return std::max((*top_k_logits)[0], input(num_classes_ - 1)); +} + +template +template +void CTCBeamSearchDecoder::Step( + const Vector& raw_input) { + std::vector top_k_logits; + std::vector top_k_indices; + const bool top_k = + (label_selection_size_ > 0 && label_selection_size_ < raw_input.size()); + // Number of character classes to consider in each step. + const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1); + // Get max coefficient and remove it from raw_input later. + float max_coeff; + if (top_k) { + max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits, + &top_k_indices); + } else { + max_coeff = raw_input.maxCoeff(); + } + const float label_selection_input_min = + (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) + : -std::numeric_limits::infinity(); + + // Extract the beams sorted in decreasing new probability + TFLITE_DCHECK_EQ(num_classes_, raw_input.size()); + + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + + for (BeamEntry* b : *branches) { + // P(.. @ t) becomes the new P(.. @ t-1) + b->oldp = b->newp; + } + + for (BeamEntry* b : *branches) { + if (b->parent != nullptr) { // if not the root + if (b->parent->Active()) { + // If last two sequence characters are identical: + // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5) + // + Pblank(l=ac @ t=5)) + // else: + // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5) + // + P(l=ab @ t=5)) + float previous = (b->label == b->parent->label) ? b->parent->oldp.blank + : b->parent->oldp.total; + b->newp.label = + LogSumExp(b->newp.label, + beam_scorer_->GetStateExpansionScore(b->state, previous)); + } + // Plabel(l=abc @ t=6) *= P(c @ 6) + b->newp.label += raw_input(b->label) - max_coeff; + } + // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) + b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) + b->newp.total = LogSumExp(b->newp.blank, b->newp.label); + + // Push the entry back to the top paths list. + // Note, this will always fill leaves back up in sorted order. + leaves_.push(b); + } + + // we need to resort branches in descending oldp order. + + // branches is in descending oldp order because it was + // originally in descending newp order and we copied newp to oldp. + + // Grow new leaves + for (BeamEntry* b : *branches) { + // A new leaf (represented by its BeamProbability) is a candidate + // iff its total probability is nonzero and either the beam list + // isn't full, or the lowest probability entry in the beam has a + // lower probability than the leaf. + auto is_candidate = [this](const BeamProbability& prob) { + return (prob.total > kLogZero && + (leaves_.size() < beam_width_ || + prob.total > leaves_.peek_bottom()->newp.total)); + }; + + if (!is_candidate(b->oldp)) { + continue; + } + + for (int ind = 0; ind < max_classes; ind++) { + const int label = top_k ? top_k_indices[ind] : ind; + const float logit = top_k ? top_k_logits[ind] : raw_input(ind); + // Perform label selection: if input for this label looks very + // unpromising, never evaluate it with a scorer. + if (logit < label_selection_input_min) { + continue; + } + BeamEntry& c = b->GetChild(label); + if (!c.Active()) { + // Pblank(l=abcd @ t=6) = 0 + c.newp.blank = kLogZero; + // If new child label is identical to beam label: + // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6) + // Otherwise: + // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) + beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); + float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; + c.newp.label = logit - max_coeff + + beam_scorer_->GetStateExpansionScore(c.state, previous); + // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) + c.newp.total = c.newp.label; + + if (is_candidate(c.newp)) { + // Before adding the new node to the beam, check if the beam + // is already at maximum width. + if (leaves_.size() == beam_width_) { + // Bottom is no longer in the beam search. Reset + // its probability; signal it's no longer in the beam search. + BeamEntry* bottom = leaves_.peek_bottom(); + bottom->newp.Reset(); + } + leaves_.push(&c); + } else { + // Deactivate child. + c.oldp.Reset(); + c.newp.Reset(); + } + } + } + } // for (BeamEntry* b... +} + +template +void CTCBeamSearchDecoder::Reset() { + leaves_.Reset(); + + // This beam root, and all of its children, will be in memory until + // the next reset. + beam_root_.reset(new BeamRoot(nullptr, -1)); + beam_root_->RootEntry()->newp.total = 0.0; // ln(1) + beam_root_->RootEntry()->newp.blank = 0.0; // ln(1) + + // Add the root as the initial leaf. + leaves_.push(beam_root_->RootEntry()); + + // Call initialize state on the root object. + beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); +} + +template +bool CTCBeamSearchDecoder::TopPaths( + int n, std::vector>* paths, std::vector* log_probs, + bool merge_repeated) const { + TFLITE_DCHECK(paths); + TFLITE_DCHECK(log_probs); + paths->clear(); + log_probs->clear(); + if (n > beam_width_) { + return false; + } + if (n > leaves_.size()) { + return false; + } + + gtl::TopN top_branches(n); + + // O(beam_width_ * log(n)), space complexity is O(n) + for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) { + top_branches.push(*it); + } + // O(n * log(n)) + std::unique_ptr> branches(top_branches.Extract()); + + for (int i = 0; i < n; ++i) { + BeamEntry* e((*branches)[i]); + paths->push_back(e->LabelSeq(merge_repeated)); + log_probs->push_back(e->newp.total); + } + return true; +} + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc new file mode 100644 index 0000000000000000000000000000000000000000..834d1ebd666db2be46394166edadf2a166d958aa --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -0,0 +1,247 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace experimental { +namespace ctc_beam_search_decoder { + +constexpr int kInputsTensor = 0; +constexpr int kSequenceLengthTensor = 1; + +typedef struct { + int beam_width; + int top_paths; + bool merge_repeated; +} CTCBeamSearchDecoderParams; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_CHECK(buffer != nullptr); + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + + CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams; + option->beam_width = m["beam_width"].AsInt32(); + option->top_paths = m["top_paths"].AsInt32(); + option->merge_repeated = m["merge_repeated"].AsBool(); + + return option; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const CTCBeamSearchDecoderParams* option = + reinterpret_cast(node->user_data); + const int top_paths = option->top_paths; + TF_LITE_ENSURE(context, option->beam_width >= top_paths); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + // The outputs should be top_paths * 3 + 1. + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1); + + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3); + // TensorFlow only supports float. + TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32); + const int batch_size = SizeOfDimension(inputs, 1); + + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1); + TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size); + // TensorFlow only supports int32. + TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32); + + // Resize decoded outputs. + // Do not resize indices & values cause we don't know the values yet. + for (int i = 0; i < top_paths; ++i) { + TfLiteTensor* indices = GetOutput(context, node, i); + SetTensorToDynamic(indices); + TfLiteTensor* values = GetOutput(context, node, i + top_paths); + SetTensorToDynamic(values); + TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths); + SetTensorToDynamic(output_shape); + } + + // Resize log probability outputs. + TfLiteTensor* log_probability_output = + GetOutput(context, node, top_paths * 3); + TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2); + log_probability_output_shape_array->data[0] = batch_size; + log_probability_output_shape_array->data[1] = top_paths; + return context->ResizeTensor(context, log_probability_output, + log_probability_output_shape_array); +} + +TfLiteStatus Resize(TfLiteContext* context, + std::initializer_list output_shape, + TfLiteTensor* output) { + const int dimensions = output_shape.size(); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions); + int i = 0; + for (const int v : output_shape) { + output_shape_array->data[i++] = v; + } + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus StoreAllDecodedSequences( + TfLiteContext* context, + const std::vector>>& sequences, + TfLiteNode* node, int top_paths) { + const int32_t batch_size = sequences.size(); + std::vector num_entries(top_paths, 0); + + // Calculate num_entries per path + for (const auto& batch_s : sequences) { + TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths); + for (int p = 0; p < top_paths; ++p) { + num_entries[p] += batch_s[p].size(); + } + } + + for (int p = 0; p < top_paths; ++p) { + const int32_t p_num = num_entries[p]; + + // Resize the decoded outputs. + TfLiteTensor* indices = GetOutput(context, node, p); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices)); + + TfLiteTensor* values = GetOutput(context, node, p + top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values)); + + TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape)); + + int32_t max_decoded = 0; + int32_t offset = 0; + + int32_t* indices_data = GetTensorData(indices); + int32_t* values_data = GetTensorData(values); + int32_t* decoded_shape_data = GetTensorData(decoded_shape); + for (int b = 0; b < batch_size; ++b) { + auto& p_batch = sequences[b][p]; + int32_t num_decoded = p_batch.size(); + max_decoded = std::max(max_decoded, num_decoded); + + std::copy_n(p_batch.begin(), num_decoded, values_data + offset); + for (int32_t t = 0; t < num_decoded; ++t, ++offset) { + indices_data[offset * 2] = b; + indices_data[offset * 2 + 1] = t; + } + } + + decoded_shape_data[0] = batch_size; + decoded_shape_data[1] = max_decoded; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + const CTCBeamSearchDecoderParams* option = + reinterpret_cast(node->user_data); + + const int max_time = SizeOfDimension(inputs, 0); + const int batch_size = SizeOfDimension(inputs, 1); + const int num_classes = SizeOfDimension(inputs, 2); + + const int beam_width = option->beam_width; + const int top_paths = option->top_paths; + const bool merge_repeated = option->merge_repeated; + + // Validate sequence length is less or equal than max time. + for (int i = 0; i < batch_size; ++i) { + TF_LITE_ENSURE(context, + max_time >= GetTensorData(sequence_length)[i]); + } + + // The following logic is implemented like + // tensorflow/core/kernels/ctc_decoder_ops.cc + std::vector::UnalignedConstMatrix> input_list_t; + + for (std::size_t t = 0; t < max_time; ++t) { + input_list_t.emplace_back( + GetTensorData(inputs) + t * batch_size * num_classes, batch_size, + num_classes); + } + + ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer + beam_scorer; + ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search( + num_classes, beam_width, &beam_scorer, 1 /* batch_size */, + merge_repeated); + + // Allocate temporary memory for holding chip operation data. + float* input_chip_t_data = + static_cast(malloc(num_classes * sizeof(float))); + Eigen::array dims; + dims[0] = num_classes; + optimized_ops::TTypes::Flat input_chip_t(input_chip_t_data, dims); + + std::vector>> best_paths(batch_size); + std::vector log_probs; + + TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths); + float* log_probabilities_output = GetTensorData(log_probabilities); + + // Assumption: the blank index is num_classes - 1 + for (int b = 0; b < batch_size; ++b) { + auto& best_paths_b = best_paths[b]; + best_paths_b.resize(top_paths); + for (int t = 0; t < GetTensorData(sequence_length)[b]; ++t) { + input_chip_t = input_list_t[t].chip(b, 0); + auto input_bi = + Eigen::Map(input_chip_t.data(), num_classes); + beam_search.Step(input_bi); + } + TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b, + &log_probs, merge_repeated)); + beam_search.Reset(); + + // Fill in log_probabilities output. + for (int bp = 0; bp < top_paths; ++bp) { + log_probabilities_output[b * top_paths + bp] = log_probs[bp]; + } + } + + free(input_chip_t_data); + return StoreAllDecodedSequences(context, best_paths, node, top_paths); +} + +} // namespace ctc_beam_search_decoder + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() { + static TfLiteRegistration r = { + ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free, + ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval}; + return &r; +} + +} // namespace experimental +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d1e6a562f00905d1db7f7e055ac1c6b1cc34f9e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace experimental { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class CTCBeamSearchDecoderOpModel : public SingleOpModel { + public: + CTCBeamSearchDecoderOpModel(std::initializer_list input_shape, + std::initializer_list sequence_length_shape, + int beam_width, int top_paths, + bool merge_repeated) { + inputs_ = AddInput(TensorType_FLOAT32); + sequence_length_ = AddInput(TensorType_INT32); + + for (int i = 0; i < top_paths * 3; ++i) { + outputs_.push_back(AddOutput(TensorType_INT32)); + } + outputs_.push_back(AddOutput(TensorType_FLOAT32)); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("beam_width", beam_width); + fbb.Int("top_paths", top_paths); + fbb.Bool("merge_repeated", merge_repeated); + }); + fbb.Finish(); + SetCustomOp("CTCBeamSearchDecoder", fbb.GetBuffer(), + Register_CTC_BEAM_SEARCH_DECODER); + BuildInterpreter({input_shape, sequence_length_shape}); + } + + int inputs() { return inputs_; } + + int sequence_length() { return sequence_length_; } + + std::vector> GetDecodedOutpus() { + std::vector> outputs; + for (int i = 0; i < outputs_.size() - 1; ++i) { + outputs.push_back(ExtractVector(outputs_[i])); + } + return outputs; + } + + std::vector GetLogProbabilitiesOutput() { + return ExtractVector(outputs_[outputs_.size() - 1]); + } + + std::vector> GetOutputShapes() { + std::vector> output_shapes; + for (const int output : outputs_) { + output_shapes.push_back(GetTensorShape(output)); + } + return output_shapes; + } + + private: + int inputs_; + int sequence_length_; + std::vector outputs_; +}; + +TEST(CTCBeamSearchTest, SimpleTest) { + CTCBeamSearchDecoderOpModel m({2, 1, 2}, {1}, 1, 1, true); + m.PopulateTensor(m.inputs(), + {-0.50922557, -1.35512652, -2.55445064, -1.58419356}); + m.PopulateTensor(m.sequence_length(), {2}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(1, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(1)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(1, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.32134813}))); +} + +TEST(CTCBeamSearchTest, MultiBatchTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 3}, {3}, 1, 1, true); + m.PopulateTensor( + m.inputs(), + {-0.63649208, -0.00487571, -0.04249819, -0.67754697, -1.0341399, + -2.14717721, -0.77686821, -3.41973774, -0.05151402, -0.21482619, + -0.57411168, -1.45039917, -0.73769373, -2.10941739, -0.44818325, + -0.25287673, -2.80057302, -0.54748312, -0.73334867, -0.86537719, + -0.2065197, -0.18725838, -1.42770405, -0.86051965, -1.61642301, + -2.07275114, -0.9201845}); + m.PopulateTensor(m.sequence_length(), {3, 3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(4)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2)); + // Check log probabilities output. + EXPECT_THAT( + m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572}))); +} + +TEST(CTCBeamSearchTest, MultiPathsTest) { + CTCBeamSearchDecoderOpModel m({3, 2, 5}, {2}, 3, 2, true); + m.PopulateTensor( + m.inputs(), + {-2.206851, -0.09542714, -0.2393415, -3.81866197, -0.27241158, + -0.20371124, -0.68236623, -1.1397166, -0.17422639, -1.85224048, + -0.9406037, -0.32544678, -0.21846784, -0.38377237, -0.33498676, + -0.10139782, -0.51886883, -0.21678554, -0.15267063, -1.91164412, + -0.31328673, -0.27462716, -0.65975336, -1.53671973, -2.76554225, + -0.23920634, -1.2370502, -4.98751576, -3.12995717, -0.43129368}); + m.PopulateTensor(m.sequence_length(), {3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 7); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[2], ElementsAre(4)); + EXPECT_THAT(output_shapes[3], ElementsAre(3)); + EXPECT_THAT(output_shapes[4], ElementsAre(2)); + EXPECT_THAT(output_shapes[5], ElementsAre(2)); + EXPECT_THAT(output_shapes[6], ElementsAre(2, 2)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 6); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 1, 1)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0, 0, 0, 1, 1, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 2, 3, 0)); + EXPECT_THAT(decoded_outputs[3], ElementsAre(2, 1, 0)); + EXPECT_THAT(decoded_outputs[4], ElementsAre(2, 2)); + EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear( + {0.91318405, 0.9060272, 1.0780245, 0.64358956}))); +} + +TEST(CTCBeamSearchTest, NonEqualSequencesTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 4}, {3}, 3, 1, true); + m.PopulateTensor( + m.inputs(), + {-1.26658163, -0.25760023, -0.03917975, -0.63772235, -0.03794756, + -0.45063099, -0.27706473, -0.01569179, -0.59940385, -0.35700127, + -0.48920721, -1.42635476, -1.3462478, -0.02565498, -0.30179568, + -0.6491698, -0.55017719, -2.92291466, -0.92522973, -0.47592022, + -0.07099135, -0.31575624, -0.86345281, -0.36017021, -0.79208612, + -1.75306124, -0.65089224, -0.00912786, -0.42915003, -1.72606203, + -1.66337589, -0.70800793, -2.52272352, -0.67329562, -2.49145522, + -0.49786342}); + m.PopulateTensor(m.sequence_length(), {1, 2, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(2, 0, 1)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005}))); +} + +} // namespace +} // namespace experimental +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h new file mode 100644 index 0000000000000000000000000000000000000000..596ad4a5f7264ae24caa5592d10c09c256629b06 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_decoder.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ + +#include +#include + +#include "third_party/eigen3/Eigen/Core" + +namespace tflite { +namespace experimental { +namespace ctc { + +// The CTCDecoder is an abstract interface to be implemented when providing a +// decoding method on the timestep output of a RNN trained with CTC loss. +// +// The two types of decoding available are: +// - greedy path, through the CTCGreedyDecoder +// - beam search, through the CTCBeamSearchDecoder +class CTCDecoder { + public: + typedef Eigen::Map SequenceLength; + typedef Eigen::Map Input; + typedef std::vector> Output; + typedef Eigen::Map ScoreOutput; + + CTCDecoder(int num_classes, int batch_size, bool merge_repeated) + : num_classes_(num_classes), + blank_index_(num_classes - 1), + batch_size_(batch_size), + merge_repeated_(merge_repeated) {} + + virtual ~CTCDecoder() {} + + // Dimensionality of the input/output is expected to be: + // - seq_len[b] - b = 0 to batch_size_ + // - input[t].rows(b) - t = 0 to timesteps; b = 0 t batch_size_ + // - output.size() specifies the number of beams to be returned. + // - scores(b, i) - b = 0 to batch_size; i = 0 to output.size() + virtual bool Decode(const SequenceLength& seq_len, + const std::vector& input, + std::vector* output, ScoreOutput* scores) = 0; + + int batch_size() { return batch_size_; } + int num_classes() { return num_classes_; } + + protected: + int num_classes_; + int blank_index_; + int batch_size_; + bool merge_repeated_; +}; + +// CTCGreedyDecoder is an implementation of the simple best path decoding +// algorithm, selecting at each timestep the most likely class at each timestep. +class CTCGreedyDecoder : public CTCDecoder { + public: + CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated) + : CTCDecoder(num_classes, batch_size, merge_repeated) {} + + bool Decode(const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + CTCDecoder::ScoreOutput* scores) override { + if (output->empty() || (*output)[0].size() < batch_size_) { + return false; + } + if (scores->rows() < batch_size_ || scores->cols() == 0) { + return false; + } + // For each batch entry, identify the transitions + for (int b = 0; b < batch_size_; ++b) { + int seq_len_b = seq_len[b]; + // Only writing to beam 0 + std::vector& output_b = (*output)[0][b]; + + int prev_class_ix = -1; + (*scores)(b, 0) = 0; + for (int t = 0; t < seq_len_b; ++t) { + auto row = input[t].row(b); + int max_class_ix; + (*scores)(b, 0) += -row.maxCoeff(&max_class_ix); + if (max_class_ix != blank_index_ && + !(merge_repeated_ && max_class_ix == prev_class_ix)) { + output_b.push_back(max_class_ix); + } + prev_class_ix = max_class_ix; + } + } + return true; + } +}; + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0bae732533716ac047a55ea31633c8ed51253fe0 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_loss_util.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ + +#include +#include + +namespace tflite { +namespace experimental { +namespace ctc { + +const float kLogZero = -std::numeric_limits::infinity(); + +// Add logarithmic probabilities using: +// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a))) +// The two inputs are assumed to be log probabilities. +// (GravesTh) Eq. 7.18 +inline float LogSumExp(float log_prob_1, float log_prob_2) { + // Always have 'b' be the smaller number to avoid the exponential from + // blowing up. + if (log_prob_1 == kLogZero && log_prob_2 == kLogZero) { + return kLogZero; + } else { + return (log_prob_1 > log_prob_2) + ? log_prob_1 + log1pf(expf(log_prob_2 - log_prob_1)) + : log_prob_2 + log1pf(expf(log_prob_1 - log_prob_2)); + } +} + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/top_n.h b/tensorflow/contrib/lite/experimental/kernels/top_n.h new file mode 100644 index 0000000000000000000000000000000000000000..cd2a2f1c80276d4659ccd2f8f05af3af030acb90 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/top_n.h @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This simple class finds the top n elements of an incrementally provided set +// of elements which you push one at a time. If the number of elements exceeds +// n, the lowest elements are incrementally dropped. At the end you get +// a vector of the top elements sorted in descending order (through Extract() or +// ExtractNondestructive()), or a vector of the top elements but not sorted +// (through ExtractUnsorted() or ExtractUnsortedNondestructive()). +// +// The value n is specified in the constructor. If there are p elements pushed +// altogether: +// The total storage requirements are O(min(n, p)) elements +// The running time is O(p * log(min(n, p))) comparisons +// If n is a constant, the total storage required is a constant and the running +// time is linear in p. +// +// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p) +// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements, +// discarding the lowest n elements whenever the buffer is full using a linear- +// time median algorithm. This may have better performance when the input +// sequence is partially sorted. +// +// NOTE(zhifengc): This class should be redesigned to avoid reallocating a +// vector for each Extract. + +// Copied from tensorflow/core/lib/gtl/top_n.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace gtl { + +// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, +// not the more commonly used "less" predicate. +// +// If you use a "less" predicate here, the TopN will pick out the bottom N +// elements out of the ones passed to it, and it will return them sorted in +// ascending order. +// +// TopN is rule-of-zero copyable and movable if its members are. +template > +class TopN { + public: + // The TopN is in one of the three states: + // + // o UNORDERED: this is the state an instance is originally in, + // where the elements are completely orderless. + // + // o BOTTOM_KNOWN: in this state, we keep the invariant that there + // is at least one element in it, and the lowest element is at + // position 0. The elements in other positions remain + // unsorted. This state is reached if the state was originally + // UNORDERED and a peek_bottom() function call is invoked. + // + // o HEAP_SORTED: in this state, the array is kept as a heap and + // there are exactly (limit_+1) elements in the array. This + // state is reached when at least (limit_+1) elements are + // pushed in. + // + // The state transition graph is at follows: + // + // peek_bottom() (limit_+1) elements + // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED + // | ^ + // | (limit_+1) elements | + // +-----------------------------------------------------------+ + + enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED }; + using UnsortedIterator = typename std::vector::const_iterator; + + // 'limit' is the maximum number of top results to return. + explicit TopN(size_t limit) : TopN(limit, Cmp()) {} + TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {} + + size_t limit() const { return limit_; } + + // Number of elements currently held by this TopN object. This + // will be no greater than 'limit' passed to the constructor. + size_t size() const { return std::min(elements_.size(), limit_); } + + bool empty() const { return size() == 0; } + + // If you know how many elements you will push at the time you create the + // TopN object, you can call reserve to preallocate the memory that TopN + // will need to process all 'n' pushes. Calling this method is optional. + void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); } + + // Push 'v'. If the maximum number of elements was exceeded, drop the + // lowest element and return it in 'dropped' (if given). If the maximum is not + // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or + // nullptr, in which case it is not filled in. + // Requires: T is CopyAssignable, Swappable + void push(const T &v) { push(v, nullptr); } + void push(const T &v, T *dropped) { PushInternal(v, dropped); } + + // Move overloads of push. + // Requires: T is MoveAssignable, Swappable + void push(T &&v) { // NOLINT(build/c++11) + push(std::move(v), nullptr); + } + void push(T &&v, T *dropped) { // NOLINT(build/c++11) + PushInternal(std::move(v), dropped); + } + + // Peeks the bottom result without calling Extract() + const T &peek_bottom(); + + // Extract the elements as a vector sorted in descending order. The caller + // assumes ownership of the vector and must delete it when done. This is a + // destructive operation. The only method that can be called immediately + // after Extract() is Reset(). + std::vector *Extract(); + + // Similar to Extract(), but makes no guarantees the elements are in sorted + // order. As with Extract(), the caller assumes ownership of the vector and + // must delete it when done. This is a destructive operation. The only + // method that can be called immediately after ExtractUnsorted() is Reset(). + std::vector *ExtractUnsorted(); + + // A non-destructive version of Extract(). Copy the elements in a new vector + // sorted in descending order and return it. The caller assumes ownership of + // the new vector and must delete it when done. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + std::vector *ExtractNondestructive() const; + + // A non-destructive version of Extract(). Copy the elements to a given + // vector sorted in descending order. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractNondestructive(std::vector *output) const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements in a new + // vector and return it, with no guarantees the elements are in sorted order. + // The caller assumes ownership of the new vector and must delete it when + // done. After calling ExtractUnsortedNondestructive(), the caller can + // continue to push() new elements. + std::vector *ExtractUnsortedNondestructive() const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements into + // a given vector, with no guarantees the elements are in sorted order. + // After calling ExtractUnsortedNondestructive(), the caller can continue + // to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractUnsortedNondestructive(std::vector *output) const; + + // Return an iterator to the beginning (end) of the container, + // with no guarantees about the order of iteration. These iterators are + // invalidated by mutation of the data structure. + UnsortedIterator unsorted_begin() const { return elements_.begin(); } + UnsortedIterator unsorted_end() const { return elements_.begin() + size(); } + + // Accessor for comparator template argument. + Cmp *comparator() { return &cmp_; } + + // This removes all elements. If Extract() or ExtractUnsorted() have been + // called, this will put it back in an empty but useable state. + void Reset(); + + private: + template + void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11) + + // elements_ can be in one of two states: + // elements_.size() <= limit_: elements_ is an unsorted vector of elements + // pushed so far. + // elements_.size() > limit_: The last element of elements_ is unused; + // the other elements of elements_ are an stl heap whose size is exactly + // limit_. In this case elements_.size() is exactly one greater than + // limit_, but don't use "elements_.size() == limit_ + 1" to check for + // that because you'll get a false positive if limit_ == size_t(-1). + std::vector elements_; + size_t limit_; // Maximum number of elements to find + Cmp cmp_; // Greater-than comparison function + State state_ = UNORDERED; +}; + +// ---------------------------------------------------------------------- +// Implementations of non-inline functions + +template +template +void TopN::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11) + if (limit_ == 0) { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + return; + } + if (state_ != HEAP_SORTED) { + elements_.push_back(std::forward(v)); // NOLINT(build/c++11) + if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) { + // Easy case: we just pushed the new element back + } else { + // To maintain the BOTTOM_KNOWN state, we need to make sure that + // the element at position 0 is always the smallest. So we put + // the new element at position 0 and push the original bottom + // element in the back. + // Warning: this code is subtle. + using std::swap; + swap(elements_.front(), elements_.back()); + } + if (elements_.size() == limit_ + 1) { + // Transition from unsorted vector to a heap. + std::make_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + state_ = HEAP_SORTED; + } + } else { + // Only insert the new element if it is greater than the least element. + if (cmp_(v, elements_.front())) { + elements_.back() = std::forward(v); // NOLINT(build/c++11) + std::push_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + } else { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + } + } +} + +template +const T &TopN::peek_bottom() { + TFLITE_DCHECK(!empty()); + if (state_ == UNORDERED) { + // We need to do a linear scan to find out the bottom element + int min_candidate = 0; + for (size_t i = 1; i < elements_.size(); ++i) { + if (cmp_(elements_[min_candidate], elements_[i])) { + min_candidate = i; + } + } + // By swapping the element at position 0 and the minimal + // element, we transition to the BOTTOM_KNOWN state + if (min_candidate != 0) { + using std::swap; + swap(elements_[0], elements_[min_candidate]); + } + state_ = BOTTOM_KNOWN; + } + return elements_.front(); +} + +template +std::vector *TopN::Extract() { + auto out = new std::vector; + out->swap(elements_); + if (state_ != HEAP_SORTED) { + std::sort(out->begin(), out->end(), cmp_); + } else { + out->pop_back(); + std::sort_heap(out->begin(), out->end(), cmp_); + } + return out; +} + +template +std::vector *TopN::ExtractUnsorted() { + auto out = new std::vector; + out->swap(elements_); + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + out->pop_back(); + } + return out; +} + +template +std::vector *TopN::ExtractNondestructive() const { + auto out = new std::vector; + ExtractNondestructive(out); + return out; +} + +template +void TopN::ExtractNondestructive(std::vector *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ != HEAP_SORTED) { + std::sort(output->begin(), output->end(), cmp_); + } else { + output->pop_back(); + std::sort_heap(output->begin(), output->end(), cmp_); + } +} + +template +std::vector *TopN::ExtractUnsortedNondestructive() const { + auto elements = new std::vector; + ExtractUnsortedNondestructive(elements); + return elements; +} + +template +void TopN::ExtractUnsortedNondestructive(std::vector *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + output->pop_back(); + } +} + +template +void TopN::Reset() { + elements_.clear(); + state_ = UNORDERED; +} + +} // namespace gtl +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index 2296f5a06449100b0b107f917b3d09f89e25d5bd..d979353bb3550fe53d86b2e6c76702a3970b01fe 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -136,3 +136,39 @@ operations instead of a single operator. 6. Use TF_LITE_ENSURE(context, condition) to check for a specific condition. Your code must not leave memory hanging when TF_LITE_ENSURE is done, i.e., these should be done before any resources are allocated that will leak. + +## Special TF Graph Attributes + +When Toco convertes a TF graph into TFLite format, it makes some assumption +about custom operations that might be not correct. In this case, the generated +graph can be not executable. + +It is possible to add aditional information about your custom op output to TF +graph before it is converted. The following attributes are supported: + +- **_output_quantized** a boolean attribute, true if the operation outputs are + quantized +- **_output_types** a list of types for output tensors +- **_output_shapes** a list of shapes for output tensors + +### Setting the Attributes + +This is an example how the attributes can be set: + +```python +frozen_graph_def = tf.graph_util.convert_variables_to_constants(...) +for node in frozen_graph_def.node: + if node.op == 'sin': + node.attr['_output_types'].list.type.extend([ + types_pb2.DT_FLOAT, + ]) + node.attr['_output_shapes'].list.shape.extend([ + tf.TensorShape([10]), + ]) + node.attr['_output_quantized'].b = False +tflite_model = tf.contrib.lite.toco_convert( + frozen_graph_def,...) +``` + +**Note:** After the attributes are set, the graph can not be executed by +Tensorflow, therefore it should be done just before the conversion. diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index 3292aece0e76244a61613b514457edf479858fdb..4ceb9a53dc0967ab6320a1bfdb1ddb859482c5dd 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -42,22 +42,22 @@ single thread large core. Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance ------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: -Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.7% | 65.8% | 3.7 ms -Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 41.9% | 69.1% | 5.5 ms -Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.3% | 71.9% | 7.9 ms -Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 46.4% | 73.8% | 10.4 ms -Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.1% | 78.9% | 8.8 ms -Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.6% | 81.3% | 13.0 ms -Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 59.1% | 83.2% | 18.3 ms -Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.0% | 84.5% | 24.7 ms -Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 52.5% | 82.8% | 16.2 ms -Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.6% | 85.5% | 24.3 ms -Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 61.1% | 87.1% | 33.8 ms -Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.7% | 88.1% | 45.4 ms -Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 62.7% | 85.5% | 24.9 ms -Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 66.6% | 87.7% | 37.4 ms -Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.9% | 51.9 ms -Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.3% | 89.5% | 70.2 ms +Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms +Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.4% | 68.5% | 5.5 ms +Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms +Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.0% | 72.8% | 10.4 ms +Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.5% | 77.7% | 8.8 ms +Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 80.4% | 13.0 ms +Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.0% | 82.2% | 18.3 ms +Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 60.7% | 83.2% | 24.7 ms +Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.8% | 78.8% | 16.2 ms +Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.3% | 83.8% | 24.3 ms +Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.4% | 33.8 ms +Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.8% | 87.0% | 45.4 ms +Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.4% | 84.2% | 24.9 ms +Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.7% | 37.4 ms +Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.3% | 51.9 ms +Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.1% | 88.9% | 70.2 ms ## Other models diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md index 613e9f97c38942f20d3ca44cdc69e72b35c8608f..5cd0aab44f10de1b76e1acb302fc1ee2711c8d74 100644 --- a/tensorflow/contrib/lite/g3doc/performance.md +++ b/tensorflow/contrib/lite/g3doc/performance.md @@ -39,7 +39,6 @@ Device | CPU_MASK | Pixel 2 | f0 | Pixel xl | 0c | - @@ -50,7 +49,7 @@ Pixel xl | 0c | @@ -61,7 +60,7 @@ Pixel xl | 0c | @@ -134,14 +133,14 @@ modified to set `num_threads` to 1. diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 0e8f4339fc7698294afd3134d4507ca1ad8b7541..aa65ec99887a61df658dd7add7b5cc3b91d81846 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -62,6 +62,7 @@ counterparts: * [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) - *as long as tensors are 2D and axis is the last dimension* * [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k) +* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) * [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as mode and constant_values are not used* * [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) - @@ -830,6 +831,18 @@ Outputs { } ``` +**LOGICAL_OR** + +``` +Inputs { + 0: a list of tensors. + 1: a list of tensors. +} +Outputs { + 0: A tensor of logical_or output tensors. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 26fecceab028c7fb86edc3f5e0a6d9d1e5556ffd..7a680f5c6400a94a2746d09891e0e39a410404a2 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -26,27 +26,24 @@ limitations under the License. #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" #include "tensorflow/contrib/lite/memory_planner.h" -#ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" -#endif #include "tensorflow/contrib/lite/profiling/profiler.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/util.h" namespace tflite { -#ifdef TFLITE_MCU -class NNAPIDelegate {}; -#endif - namespace { TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node, const TfLiteRegistration& registration, int node_index, const char* message) { - context->ReportError(context, "Node number %d (%s) %s.\n", node_index, - EnumNameBuiltinOperator(static_cast( - registration.builtin_code)), - message); + context->ReportError( + context, "Node number %d (%s) %s.\n", node_index, + registration.custom_name + ? registration.custom_name + : EnumNameBuiltinOperator( + static_cast(registration.builtin_code)), + message); return kTfLiteError; } @@ -131,9 +128,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.SetExternalContext = SetExternalContext; // Invalid to call these these except from TfLiteDelegate - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kTensorsReservedCapacity); @@ -278,8 +273,9 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( int node_index; TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph); - AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, - nullptr, 0, params, ®istration, &node_index); + TF_LITE_ENSURE_STATUS(AddNodeWithParameters( + subgraph.input_tensors, subgraph.output_tensors, nullptr, 0, params, + ®istration, &node_index)); // Initialize the output tensors's delegate-related fields. for (int tensor_index : subgraph.output_tensors) { @@ -628,7 +624,6 @@ TfLiteStatus Interpreter::Invoke() { } TfLiteStatus status = kTfLiteOk; -#ifndef TFLITE_MCU if (nnapi_delegate_) { if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); @@ -642,7 +637,6 @@ TfLiteStatus Interpreter::Invoke() { return kTfLiteError; } } -#endif // Invocations are always done in node order. // Note that calling Invoke repeatedly will cause the original memory plan to @@ -900,17 +894,15 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, } void Interpreter::UseNNAPI(bool enable) { -#ifndef TFLITE_MCU // TODO(aselle): This is a workaround for finding if NNAPI exists. // We also need to make sure getLibraryHandle() is renamed to be NNAPI // prefixed. - if (!NNAPIExists()) enable = false; + if (!NNAPIDelegate::IsSupported()) enable = false; if (!enable) { nnapi_delegate_.reset(); } else if (!nnapi_delegate_) { nnapi_delegate_.reset(new NNAPIDelegate); } -#endif } void Interpreter::SetNumThreads(int num_threads) { @@ -924,6 +916,19 @@ void Interpreter::SetNumThreads(int num_threads) { } } +void Interpreter::SwitchToDelegateContext() { + context_.GetNodeAndRegistration = GetNodeAndRegistration; + context_.ReplaceSubgraphsWithDelegateKernels = + ReplaceSubgraphsWithDelegateKernels; + context_.GetExecutionPlan = GetExecutionPlan; +} + +void Interpreter::SwitchToKernelContext() { + SetForbiddenContextFunction(&context_.GetNodeAndRegistration); + SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); + SetForbiddenContextFunction(&context_.GetExecutionPlan); +} + TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, bool allow_dynamic_tensors) { if (!allow_dynamic_tensors) { @@ -950,17 +955,12 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, // TODO(aselle): Consider if it is worth storing pointers to delegates. // Setup additional context interface. - context_.GetNodeAndRegistration = GetNodeAndRegistration; - context_.ReplaceSubgraphsWithDelegateKernels = - ReplaceSubgraphsWithDelegateKernels; - context_.GetExecutionPlan = GetExecutionPlan; + SwitchToDelegateContext(); TfLiteStatus status = delegate->Prepare(&context_, delegate); // Remove additional context info. - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); TF_LITE_ENSURE_OK(&context_, status); diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index bc608e2fcecfd0fa6f03841542f06911e3f88f29..e8301ff5076ec104d09351d081a28f5eb0964bc6 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -111,7 +111,7 @@ class Interpreter { // processing this model will be forwarded to the error_reporter object. // // Note, if error_reporter is nullptr, then a default StderrReporter is - // used. + // used. Ownership of 'error_reporter' remains with the caller. explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); ~Interpreter(); @@ -416,6 +416,13 @@ class Interpreter { private: friend class InterpreterTest; + // Prevent 'context_' from accessing functions that are only available to + // delegated kernels. + void SwitchToKernelContext(); + + // Add delegate-only functions to 'context_'. + void SwitchToDelegateContext(); + // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, @@ -502,6 +509,7 @@ class Interpreter { // Update the execution graph to replace some of the nodes with stub // nodes. Specifically any node index that has `nodes[index]==1` will be // slated for replacement with a delegate kernel specified by registration. + // Ownership of 'nodes_to_replace' and 'delegate' remains with the caller. // WARNING: This is an experimental interface that is subject to change. TfLiteStatus ReplaceSubgraphsWithDelegateKernels( TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, @@ -519,12 +527,13 @@ class Interpreter { TfLiteRegistration** registration); // WARNING: This is an experimental interface that is subject to change. - // Gets an TfLiteIntArray* representing the execution plan. The caller owns - // this memory and must free it with TfLiteIntArrayFree(). + // Gets an TfLiteIntArray* representing the execution plan. The interpreter + // owns this memory and it is only guaranteed to exist during the invocation + // of the delegate prepare. TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan); // WARNING: This is an experimental interface that is subject to change. - // Entry point for C node plugin API to get the execution plan + // Entry point for C node plugin API to get the execution plan. static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 10119903fed448bd44f408efb495831216dc594c..2bf598bad71b87afaa22c1eb95474c49386c122f 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -647,18 +647,6 @@ TEST(BasicInterpreter, AllocateTwice) { ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - char buffer[1024]; - int size = vsnprintf(buffer, sizeof(buffer), format, args); - all_reports += buffer; - calls++; - return size; - } - int calls = 0; - std::string all_reports; -}; - TEST(BasicInterpreter, TestNullErrorReporter) { TestErrorReporter reporter; Interpreter interpreter; @@ -668,8 +656,9 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { TestErrorReporter reporter; Interpreter interpreter(&reporter); ASSERT_NE(interpreter.Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.error_messages(), + "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.num_calls(), 1); } TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index e2c1edd9afbac9ca75262e1e41f7cf6334564dec..fdcf00a0a08459d8d669f1def3ae2eb21dbd31c3 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -152,10 +152,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( if (error_reporter == nullptr) return; if (interpreter->AllocateTensors() != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Cannot allocate memory for the interpreter:" - " %s", - error_reporter->CachedErrorMessage()); + throwException( + env, kIllegalStateException, + "Internal error: Unexpected failure when preparing tensor allocations:" + " %s", + error_reporter->CachedErrorMessage()); } } @@ -336,10 +337,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( // allocates memory status = interpreter->AllocateTensors(); if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Cannot allocate memory for the interpreter:" - " %s", - error_reporter->CachedErrorMessage()); + throwException( + env, kIllegalStateException, + "Internal error: Unexpected failure when preparing tensor allocations:" + " %s", + error_reporter->CachedErrorMessage()); return 0; } return reinterpret_cast(interpreter.release()); diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index c224132cae152bd7eab1a6a397303bcb571b0e29..c5586475ec258849948ff6b960abc846e2ea1b3c 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -8,6 +8,19 @@ load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# Suppress warnings that are introduced by Eigen Tensor. +EXTRA_EIGEN_COPTS = select({ + "//tensorflow:ios": [ + "-Wno-error=invalid-partial-specialization", + "-Wno-error=reorder", + ], + "//tensorflow:windows": [ + "/DEIGEN_HAS_C99_MATH", + "/DEIGEN_AVOID_STL_ARRAY", + ], + "//conditions:default": ["-Wno-error=reorder"], +}) + tf_cc_test( name = "optional_tensor_test", size = "small", @@ -49,13 +62,7 @@ cc_library( hdrs = [ "eigen_support.h", ], - copts = tflite_copts() + [ - "-Wno-error=reorder", - ] + select({ - "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], - "//conditions:default": [ - ], - }), + copts = tflite_copts() + EXTRA_EIGEN_COPTS, deps = [ ":op_macros", "//tensorflow/contrib/lite:arena_planner", @@ -170,12 +177,14 @@ cc_library( "hashtable_lookup.cc", "l2norm.cc", "local_response_norm.cc", + "logical.cc", "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", "mfcc.cc", "mul.cc", "neg.cc", + "one_hot.cc", "pack.cc", "pad.cc", "pooling.cc", @@ -207,14 +216,7 @@ cc_library( "padding.h", "register.h", ], - # Suppress warnings that are introduced by Eigen Tensor. - copts = tflite_copts() + [ - "-Wno-error=reorder", - ] + select({ - "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], - "//conditions:default": [ - ], - }), + copts = tflite_copts() + EXTRA_EIGEN_COPTS, deps = [ ":activation_functor", ":eigen_support", @@ -1171,6 +1173,33 @@ tf_cc_test( ], ) +tf_cc_test( + name = "one_hot_test", + size = "small", + srcs = ["one_hot_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "logical_test", + size = "small", + srcs = ["logical_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 6e13b8c667c5c5188c9e1bc753346f231ae8e1b0..817266a47147980699a348a5c26ed637828e80c6 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -212,25 +212,25 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); const TfLiteTensor* alpha = GetInput(context, node, 1); - output->type = input->type; - // Currently only Float32 is supported // TODO(ycling): Support other data types. TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32); + output->type = input->type; - // Currently, only support 4D `input` and 3D `alpha` with shape - // (1, 1, channels). - // TODO(impjdi): Support other cases where `alpha` is broadcastable - // to `input`. - TF_LITE_ENSURE_EQ(context, input->dims->size, 4); - TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]); + // PRelu (parameteric Relu) shares the same alpha value on "shared axis". + // This means it's always required to "broadcast" alpha values in PRelu. + TfLiteIntArray* output_size = nullptr; + TF_LITE_ENSURE_OK( + context, CalculateShapeForBroadcast(context, input, alpha, &output_size)); - return context->ResizeTensor(context, output, - TfLiteIntArrayCopy(input->dims)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + // After broadcasting, the output shape should always be the same as the + // input shape. + TF_LITE_ENSURE(context, HaveSameShapes(input, output)); + + return kTfLiteOk; } TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { @@ -524,33 +524,24 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } } +template +T ApplyPrelu(T input, T alpha) { + return input >= 0.0 ? input : input * alpha; +} + TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); const TfLiteTensor* alpha = GetInput(context, node, 1); - const TfLiteTensor* output = GetOutput(context, node, 0); - + TfLiteTensor* output = GetOutput(context, node, 0); if (input->type != kTfLiteFloat32) { context->ReportError(context, "Only float32 supported currently, got %d.", input->type); return kTfLiteError; } - TF_LITE_ENSURE_EQ(context, input->dims->size, 4); - const int batches = input->dims->data[0]; - const int height = input->dims->data[1]; - const int width = input->dims->data[2]; - const int channels = input->dims->data[3]; - - TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels); - - const int n = batches * height * width * channels; - for (int i = 0; i < n; ++i) { - const float x = input->data.f[i]; - output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x; - } - + reference_ops::BroadcastBinaryFunction( + GetTensorData(input), GetTensorDims(input), + GetTensorData(alpha), GetTensorDims(alpha), + GetTensorData(output), GetTensorDims(output), ApplyPrelu); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index f678f48fa5bbbcece6c5b87030d951783378d78f..8b4d778332afd5f4b53509bd669a674c63d9f6f9 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -57,6 +57,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } +// TODO(ruic): optimize macros below to using template functions. +#define TF_LITE_QUANTIZE_COMPARISON(opname) \ + void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node, \ + const TfLiteTensor* input1, \ + const TfLiteTensor* input2, TfLiteTensor* output, \ + bool requires_broadcast) { \ + if (input1->type == kTfLiteUInt8) { \ + auto input1_offset = -input1->params.zero_point; \ + auto input2_offset = -input2->params.zero_point; \ + const int left_shift = 20; \ + const double twice_max_input_scale = \ + 2 * std::max(input1->params.scale, input2->params.scale); \ + const double real_input1_multiplier = \ + input1->params.scale / twice_max_input_scale; \ + const double real_input2_multiplier = \ + input2->params.scale / twice_max_input_scale; \ + \ + int32 input1_multiplier; \ + int input1_shift; \ + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \ + &input1_multiplier, &input1_shift); \ + int32 input2_multiplier; \ + int input2_shift; \ + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \ + &input2_multiplier, &input2_shift); \ + \ + if (requires_broadcast) { \ + reference_ops::Broadcast##opname( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + input2_offset, input2_multiplier, input2_shift, \ + GetTensorData(output), GetTensorDims(output)); \ + } else { \ + reference_ops::opname( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + input2_offset, input2_multiplier, input2_shift, \ + GetTensorData(output), GetTensorDims(output)); \ + } \ + } \ + } +TF_LITE_QUANTIZE_COMPARISON(Equal); +TF_LITE_QUANTIZE_COMPARISON(NotEqual); +TF_LITE_QUANTIZE_COMPARISON(Greater); +TF_LITE_QUANTIZE_COMPARISON(GreaterEqual); +TF_LITE_QUANTIZE_COMPARISON(Less); +TF_LITE_QUANTIZE_COMPARISON(LessEqual); +#undef TF_LITE_QUANTIZE_COMPARISON + #define TF_LITE_COMPARISON(type, opname, requires_broadcast) \ requires_broadcast \ ? reference_ops::Broadcast##opname( \ @@ -73,7 +124,6 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Equal, requires_broadcast); @@ -84,9 +134,13 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -99,7 +153,6 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); @@ -110,9 +163,13 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedNotEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -124,7 +181,6 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Greater, requires_broadcast); @@ -135,9 +191,13 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedGreater(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -149,7 +209,6 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast); @@ -160,9 +219,13 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedGreaterEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -174,7 +237,6 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Less, requires_broadcast); @@ -185,9 +247,13 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedLess(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -199,7 +265,6 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, LessEqual, requires_broadcast); @@ -210,9 +275,13 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedLessEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index bb02e1c812fdc40bf515f1f978e9e39b5a16a4ea..67a91c17fd4f25e4a9ea22de5e2a10dc1c17656d 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -35,6 +35,15 @@ class ComparisonOpModel : public SingleOpModel { BuildInterpreter({input1_shape, input2_shape}); } + ComparisonOpModel(const TensorData& input1, const TensorData& input2, + TensorType input_type, BuiltinOperator op) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(TensorType_BOOL); + ConfigureBuiltinOp(op); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + int input1() { return input1_; } int input2() { return input2_; } @@ -354,6 +363,192 @@ TEST(ComparisonsTest, LessEqualBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } +TEST(QuantizedComparisonsTest, EqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false)); +} + +TEST(QuantizedComparisonsTest, NotEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 7, 0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true)); +} + +TEST(ComparisonsTest, GreaterQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); +} + +TEST(ComparisonsTest, GreaterEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); +} + +TEST(ComparisonsTest, LessQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); +} + +TEST(ComparisonsTest, LessEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); +} + +TEST(ComparisonsTest, QuantizedEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, false, false, false, false)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedNotEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, true, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedGreaterWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, false, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedGreaterEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedLessWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, false)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedLessEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, true, false, false)) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index ad211e9c67eed9ca70fcdd51171fdb70bd89b27c..605a20ac3e7c8346db2bcf64e9422132b433b3da 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -57,7 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, t0->dims->size <= 4); TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || + input_type == kTfLiteInt64); // Output dimensions will match input dimensions, except 'axis', which // will be the sum of inputs @@ -121,6 +123,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION(optimized_ops, float); } break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int32); + } else { + TF_LITE_CONCATENATION(optimized_ops, int32); + } + break; case kTfLiteUInt8: if (kernel_type == kReference) { TF_LITE_CONCATENATION_QUANTIZED(reference_ops); @@ -128,6 +137,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); } break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int64_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, int64_t); + } + break; + default: context->ReportError(context, "Only float32 and uint8 are currently supported."); diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 6f174763dfab9845d991b930e44b07a95e00d824..04c0263b789e75727ed3bd4d6b3292063a4530e0 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -256,10 +256,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOneExp( - real_multiplier, &data->output_multiplier, &data->output_shift); - data->output_shift *= -1; + + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 0dcfc826fd218d2d2dfbf89201d2c13fbfe6f0e1..24633c2fd7cb3725977ae6c6459daa829165ccfd 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -64,12 +64,6 @@ class BaseConvolutionOpModel : public SingleOpModel { } output_ = AddOutput(output); - if (input.type != TensorType_FLOAT32) { - // The following is required by quantized inference. It is the unittest's - // responsibility to make sure the output scale falls into the correct - // range. - CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); - } SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, CreateConv2DOptions( @@ -441,6 +435,44 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantized) { })); } +TEST_P(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) { + // output_multiplier = 1.0118 + QuantizedConvolutionOpModel quant_op( + GetRegistration(), {TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128}, + {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + ConvolutionOpModel float_op( + GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); + std::initializer_list input = { + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }; + std::initializer_list bias = {1, 2, 3}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { QuantizedConvolutionOpModel m(GetRegistration(), {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc index 0c532cac5a9f59c8b09ff9aefc294e243561f027..d7bde0ff79bd23fa4c277dd04ec4343663e0ad00 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -40,8 +40,8 @@ constexpr int kOutputTensorDetectionClasses = 1; constexpr int kOutputTensorDetectionScores = 2; constexpr int kOutputTensorNumDetections = 3; -constexpr size_t kNumCoordBox = 4; -constexpr size_t kBatchSize = 1; +constexpr int kNumCoordBox = 4; +constexpr int kBatchSize = 1; // Object Detection model produces axis-aligned boxes in two formats: // BoxCorner represents the upper right (xmin, ymin) and diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 59bab3c4ecd20bf938919ca606a5933f3112f233..e19779ea59d441984d3562508e4237e10ce17515 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -22,79 +22,118 @@ namespace tflite { namespace ops { namespace builtin { namespace elementwise { +namespace { +bool IsNumericSupportedType(const TfLiteType type) { + return type == kTfLiteFloat32; +} + +bool IsLogicalSupportedType(const TfLiteType type) { + return type == kTfLiteBool; +} + +typedef bool (*IsSupportedType)(TfLiteType); +template TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - // Quantized float is not supported yet. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + if (!IsSupportedType(input->type)) { + context->ReportError(context, "Current data type %d is not supported.", + input->type); + return kTfLiteError; + } return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims)); } -inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, - float float_func(float)) { +template +inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, + T func(T), TfLiteType expected_type) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - switch (input->type) { - case kTfLiteFloat32: { - size_t elements = NumElements(input); - const float* in = GetTensorData(input); - const float* in_end = in + elements; - float* out = output->data.f; - for (; in < in_end; in++, out++) *out = float_func(*in); - return kTfLiteOk; - } - default: { - context->ReportError(context, "Input type is %d, requires float32", - input->type); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, input->type, expected_type); + const int64_t num_elements = NumElements(input); + const T* in_data = GetTensorData(input); + T* out_data = GetTensorData(output); + for (int64_t i = 0; i < num_elements; ++i) { + out_data[i] = func(in_data[i]); } + return kTfLiteOk; +} + +inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { + return EvalImpl(context, node, float_func, kTfLiteFloat32); +} + +inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node, + bool bool_func(bool)) { + return EvalImpl(context, node, bool_func, kTfLiteBool); } TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::sin); + return EvalNumeric(context, node, std::sin); } TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::log); + return EvalNumeric(context, node, std::log); } TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::sqrt); + return EvalNumeric(context, node, std::sqrt); } TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); }); + return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); +} + +TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { + return EvalLogical(context, node, [](bool v) { return !v; }); } +} // namespace } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::SinEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::SinEval}; return &r; } TfLiteRegistration* Register_LOG() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::LogEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::LogEval}; return &r; } TfLiteRegistration* Register_SQRT() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::SqrtEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::SqrtEval}; return &r; } TfLiteRegistration* Register_RSQRT() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::RsqrtEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::RsqrtEval}; + return &r; +} + +TfLiteRegistration* Register_LOGICAL_NOT() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::LogicalNotEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index ce4c602ee5c788d67701af3ecd3e023f2b25aae7..b9d7d73c52862da9166f6881b1e27a6ff6b76bbc 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,26 +24,40 @@ namespace { using ::testing::ElementsAreArray; -class ElementWiseOpModel : public SingleOpModel { +class ElementWiseOpBaseModel : public SingleOpModel { public: - ElementWiseOpModel(BuiltinOperator op, - std::initializer_list input_shape) { + int input() const { return input_; } + int output() const { return output_; } + + protected: + int input_; + int output_; +}; + +class ElementWiseOpFloatModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpFloatModel(BuiltinOperator op, + std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } +}; - int input() const { return input_; } - int output() const { return output_; } - - private: - int input_; - int output_; +class ElementWiseOpBoolModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpBoolModel(BuiltinOperator op, + std::initializer_list input_shape) { + input_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); + BuildInterpreter({input_shape}); + } }; TEST(ElementWise, Sin) { - ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -52,7 +66,7 @@ TEST(ElementWise, Sin) { } TEST(ElementWise, Log) { - ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -61,7 +75,7 @@ TEST(ElementWise, Log) { } TEST(ElementWise, Sqrt) { - ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 1, 2, 4}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -70,7 +84,7 @@ TEST(ElementWise, Sqrt) { } TEST(ElementWise, Rsqrt) { - ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {1, 2, 4, 9}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -78,6 +92,15 @@ TEST(ElementWise, Rsqrt) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, LogicalNot) { + ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {true, false, true, false}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({false, true, false, true})); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 3a855fe3ddaa7e7de0134f8dfee1ccf67168541a..0d424071da23010afe5f15a61e0ea6e45b4e6742 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -481,6 +481,9 @@ cc_library( ":darwin": [ ":neon_tensor_utils", ], + ":darwin_x86_64": [ + ":neon_tensor_utils", + ], "//conditions:default": [ ":portable_tensor_utils", ], diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 310a8980e6943db3804b0671a21ccf0e6ce34c28..eb4d0108bd0438dd27744a864d071cfc166a7a94 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -117,6 +117,9 @@ template int CountLeadingZeros(T integer_input) { static_assert(std::is_unsigned::value, "Only unsigned integer types handled."); +#if defined(__GNUC__) + return integer_input ? __builtin_clz(integer_input) : 0; +#else const T one_in_leading_positive = static_cast(1) << (std::numeric_limits::digits - 1); int leading_zeros = 0; @@ -125,6 +128,7 @@ int CountLeadingZeros(T integer_input) { ++leading_zeros; } return leading_zeros; +#endif } // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h index d85e06a5d5af8d23235a08592d49754e4f493d34..250872c422a3ff9b3353d0055513ff1f7f03d68e 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -33,7 +33,7 @@ limitations under the License. #include #ifdef _WIN32 -#include +#include #elif defined(__APPLE__) #include #else diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 45c9f65b645616b516875b10155760d7c6ab59da..63c89d1eeef47b206fc871929f1fb1295b2f70ff 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -115,10 +115,10 @@ void ClipVector(const float* vector, int v_size, float abs_limit, } void SymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, float* max, - float* scaling_factor) { - NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, min, - max, scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor) { + NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, + min_value, max_value, scaling_factor); } void VectorShiftLeft(float* vector, int v_size, float shift_value) { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 78567d52eaab779c724d3e3d04fbaf92fe6e589b..6adb879c71e6a02007dacd5ed9f91b04b2094fe7 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -168,6 +168,18 @@ ArrayMap MapAsArrayWithFirstDimAsRows(Scalar* data, return ArrayMap(data, rows, cols); } +// Copied from tensorflow/core/framework/tensor_types.h +template +struct TTypes { + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Flat; + typedef Eigen::TensorMap< + Eigen::Tensor> + UnalignedConstMatrix; +}; + // TODO(b/62193649): this function is only needed as long // as we have the --variable_batch hack. template @@ -1018,10 +1030,10 @@ inline void FullyConnectedAsGEMV( struct GemmlowpOutputPipeline { typedef gemmlowp::VectorMap ColVectorMap; - typedef std::tuple< - gemmlowp::OutputStageBiasAddition, - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, - gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> + typedef std::tuple, + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent, + gemmlowp::OutputStageClamp, + gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; static Pipeline MakeExp(const int32* bias_data, int output_rows, int32 output_offset, int32 output_multiplier, @@ -1030,11 +1042,10 @@ struct GemmlowpOutputPipeline { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint - quantize_down_stage; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; quantize_down_stage.result_offset_after_shift = output_offset; quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; - quantize_down_stage.result_shift = -output_left_shift; + quantize_down_stage.result_exponent = output_left_shift; gemmlowp::OutputStageClamp clamp_stage; clamp_stage.min = output_activation_min; clamp_stage.max = output_activation_max; @@ -2315,7 +2326,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, ++*output_shift; } TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; *output_shift -= left_shift_bit_pairs; @@ -4023,7 +4035,7 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, // perform a division by the above-computed sum-of-exponentials. int32 fixed_sum_of_exps = sum_of_exps.raw(); int headroom_plus_one = - __builtin_clz(static_cast(fixed_sum_of_exps)); + CountLeadingZeros(static_cast(fixed_sum_of_exps)); // This is the number of bits to the left of the binary point above 1.0. // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and // no later adjustment will be needed. @@ -4169,7 +4181,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl( // required shift "ourselves" instead of using, say, Rescale. FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); // z_a_pow_2 = input_integer_bits - z_a_headroom; - int z_a_headroom_plus_1 = __builtin_clz(static_cast(z_a.raw())); + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); FixedPoint0 r_a_tmp = SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); const int32 r_a_raw = @@ -4184,7 +4196,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl( // z_b is treated like z_a, but premultiplying by sqrt(0.5). FixedPoint0 z_b = z_a * sqrt_half; - int z_b_headroom = __builtin_clz(static_cast(z_b.raw())) - 1; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; const int32 r_b_raw = SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); const FixedPointAccum z_b_pow_2_adj = SaturatingSub( diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index db7926df9af955df8f9d2a87898b35e1ea2c962e..010b40b901e2821c36367da7e2c987fac113de11 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -19,6 +19,10 @@ limitations under the License. // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + #ifndef USE_NEON #if defined(__ARM_NEON__) || defined(__ARM_NEON) #define USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index e224980493aa11f642da103ee7d7377b6c4b1da0..f882f9910e0c65d69eb5a86886bae4d3c881e6ab 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -109,12 +109,12 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) { void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, float* nudged_min, float* nudged_max, - float* scale) { + float* nudged_scale) { // This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h. const float quant_min_float = static_cast(quant_min); const float quant_max_float = static_cast(quant_max); - *scale = (max - min) / (quant_max_float - quant_min_float); - const float zero_point_from_min = quant_min_float - min / *scale; + *nudged_scale = (max - min) / (quant_max_float - quant_min_float); + const float zero_point_from_min = quant_min_float - min / *nudged_scale; uint16 nudged_zero_point; if (zero_point_from_min < quant_min_float) { nudged_zero_point = static_cast(quant_min); @@ -123,8 +123,25 @@ void NudgeQuantizationRange(const float min, const float max, } else { nudged_zero_point = static_cast(TfLiteRound(zero_point_from_min)); } - *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); - *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); + *nudged_min = (quant_min_float - nudged_zero_point) * (*nudged_scale); + *nudged_max = (quant_max_float - nudged_zero_point) * (*nudged_scale); +} + +void FakeQuantizeArray(const float nudged_scale, const float nudged_min, + const float nudged_max, const float* input_data, + float* output_data, const float size) { + // This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h. + const float inv_nudged_scale = 1.0f / nudged_scale; + + for (int i = 0; i < size; i++) { + const float src_val = input_data[i]; + const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); + const float clamped_shifted = clamped - nudged_min; + const float dst_val = + TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + + nudged_min; + output_data[i] = dst_val; + } } bool CheckedLog2(const float x, int* log2_result) { diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 9b3f1823dc7e08562d8906346bc44e4478642ddc..9ee4a47fbb5bba1a409830f99c7b9ba967325a0a 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -222,7 +222,15 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift); // Outputs nudged_min, nudged_max, nudged_scale. void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, - float* nudged_min, float* nudged_max, float* scale); + float* nudged_min, float* nudged_max, + float* nudged_scale); + +// Fake quantizes (quantizes and dequantizes) input_data using the scale, +// nudged_min, and nudged_max from NudgeQuantizationRange. This matches the code +// in TensorFlow's FakeQuantizeWithMinMaxVarsFunctor. +void FakeQuantizeArray(const float nudged_scale, const float nudged_min, + const float nudged_max, const float* input_data, + float* output_data, const float size); // If x is approximately a power of two (with any positive or negative // exponent), stores that exponent (i.e. log2(x)) in *log2_result, otherwise diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index 7ead449ca84ec2e5bbff9c2be6c6b1c6fe3b018c..a5f4addd5e9297124ff4f6d6011093fac101f9f0 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -14,12 +14,17 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -37,15 +42,13 @@ bool PortableIsZeroVector(const float* vector, int v_size) { } void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, - float* __restrict__ min, - float* __restrict__ max, - float* __restrict__ scaling_factor) { + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor) { auto minmax = std::minmax_element(values, values + size); - *min = *minmax.first; - *max = *minmax.second; + *min_value = *minmax.first; + *max_value = *minmax.second; const int kScale = 127; - const float range = std::max(std::abs(*min), std::abs(*max)); + const float range = std::max(std::abs(*min_value), std::abs(*max_value)); if (range == 0) { memset(quantized_values, 0, size * sizeof(int8_t)); *scaling_factor = 1; @@ -81,9 +84,8 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, - const float* __restrict__ scaling_factors, int n_batch, - float* __restrict__ result, int result_stride) { + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, float* __restrict__ result, int result_stride) { int batch, row, col; for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) { const float batch_scaling_factor = scaling_factors[batch]; @@ -92,9 +94,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate( for (row = 0; row < m_rows; ++row, result += result_stride) { // Initialize the dot product sum for the row to 0. int32_t dotprod = 0; +#if defined(__GNUC__) // Prefetch the row to cache. __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */); +#endif // For every block of 16 8-bit elements (128-bit register) from each row. for (col = 0; col < m_cols; ++col, ++row_ptr) { dotprod += (*row_ptr) * (vectors[col]); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index d3a4fa8507c19b2ced513a115148dc9f3fda3714..a375aaffa67ac19975cc8e371be11547d689dc72 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -19,6 +19,10 @@ limitations under the License. // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -28,8 +32,8 @@ float PortableClip(float f, float abs_limit); bool PortableIsZeroVector(const float* vector, int v_size); void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, - float* max, float* scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor); // Multiply a matrix by a batch vector, and store results in a batch-size // vector. diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 31a54c2b62efcf8b6c066ba9d8b49d5ff8306854..ace3af2da06b31cea7f3d7e60d086f9ff6d7c0ce 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -322,8 +322,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( - acc, output_multiplier, kReverseShift * output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -903,7 +903,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, ++*output_shift; } TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; *output_shift -= left_shift_bit_pairs; @@ -3155,18 +3156,9 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, float nudged_min, nudged_max, nudged_scale; NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, &nudged_max, &nudged_scale); - const float inv_nudged_scale = 1.0f / nudged_scale; - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; i++) { - const float src_val = input_data[i]; - const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[i] = dst_val; - } + FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data, + output_data, flat_size); } template @@ -3284,7 +3276,8 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const Dims<4>& block_shape_dims, const int32* paddings_data, const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, + const int32_t pad_value) { const int output_batch_size = ArraySize(output_dims, 3); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); @@ -3309,7 +3302,7 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, padding_top + input_height || out_w * block_shape_width + shift_w < padding_left || out_w * block_shape_width + shift_w >= padding_left + input_width) { - memset(out, 0, depth * sizeof(T)); + memset(out, pad_value, depth * sizeof(T)); } else { const T* in = input_data + @@ -3324,6 +3317,17 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, } } +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims, + paddings_data, paddings_dims, output_data, output_dims, 0); +} + template inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -4178,8 +4182,8 @@ inline void RankOneSelect(const D* input_condition_data, } // For easy implementation, the indices is always a vector of size-4 vectors. -template -inline void SparseToDense(const std::vector>& indices, +template +inline void SparseToDense(const std::vector>& indices, const T* values, T default_value, T* output_data, const Dims<4>& output_dims, bool value_is_scalar) { const int value_count = indices.size(); @@ -4194,7 +4198,7 @@ inline void SparseToDense(const std::vector>& indices, // condition within the loop every time. if (value_is_scalar) { for (int i = 0; i < value_count; ++i) { - const std::vector& index = indices[i]; + const std::vector& index = indices[i]; TFLITE_DCHECK_EQ(index.size(), 4); const T value = *values; // just use the first value. output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = @@ -4205,7 +4209,7 @@ inline void SparseToDense(const std::vector>& indices, // Go through the values and indices to fill the sparse values. for (int i = 0; i < value_count; ++i) { - const std::vector& index = indices[i]; + const std::vector& index = indices[i]; TFLITE_DCHECK_EQ(index.size(), 4); const T value = values[i]; output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = @@ -4243,6 +4247,65 @@ inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, } } +inline void Logical(const bool* input1_data, const Dims<4>& input1_dims, + const bool* input2_data, const Dims<4>& input2_dims, + bool* output_data, const Dims<4>& output_dims, + const std::function& func) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = func(input1_data[i], input2_data[i]); + } +} + +inline void BroadcastLogical(const bool* input1_data, + const Dims<4>& input1_dims, + const bool* input2_data, + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims, + const std::function& func) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more +// generalized and efficient BroadcastBinaryFunction. +// +// R: Result type. T1: Input 1 type. T2: Input 2 type. +template +inline void BroadcastBinaryFunction(const T1* input1_data, + const Dims<4>& input1_dims, + const T2* input2_data, + const Dims<4>& input2_dims, R* output_data, + const Dims<4>& output_dims, + R (*func)(T1, T2)) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc index 4eddf7bf0a2cbca695dae20ba8ba56a9cd72e4ba..20abcb725859d03f83c969369bddf1429895e0ba 100644 --- a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc +++ b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc @@ -43,13 +43,13 @@ bool Spectrogram::Initialize(int window_length, int step_length) { return Initialize(window, step_length); } -inline int Log2Floor(uint n) { +inline int Log2Floor(uint32_t n) { if (n == 0) return -1; int log = 0; - uint value = n; + uint32_t value = n; for (int i = 4; i >= 0; --i) { int shift = (1 << i); - uint x = value >> shift; + uint32_t x = value >> shift; if (x != 0) { value = x; log += shift; @@ -58,7 +58,7 @@ inline int Log2Floor(uint n) { return log; } -inline int Log2Ceiling(uint n) { +inline int Log2Ceiling(uint32_t n) { int floor = Log2Floor(n); if (n == (n & ~(n - 1))) // zero or a power of two return floor; @@ -66,7 +66,7 @@ inline int Log2Ceiling(uint n) { return floor + 1; } -inline uint NextPowerOfTwo(uint value) { +inline uint32_t NextPowerOfTwo(uint32_t value) { int exponent = Log2Ceiling(value); // DCHECK_LT(exponent, std::numeric_limits::digits); return 1 << exponent; diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 82f450312784a1864dc7732dad2a75d2d6ae90f4..1ff8cfe39c9aed7e9241815dc8eff7ab4d9fd585 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -31,8 +35,8 @@ bool IsZeroVector(const float* vector, int v_size); // It also outputs the range (min, max) of the floating point buffer, and the // scaling factor used to quantize the values. void SymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, float* max, - float* scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor); // Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch // dimension composed by input vectors independent from each other). The result diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc new file mode 100644 index 0000000000000000000000000000000000000000..87c2fee667ccaf7bfdc4e2316309d2abc35b5324 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace logical { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for logical op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Reinterprete the opaque data provided by user. + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteBool) { + context->ReportError(context, "Logical ops only support bool type."); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node, + const std::function& func) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (data->requires_broadcast) { + reference_ops::BroadcastLogical( + GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output), func); + } else { + reference_ops::Logical(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output), + func); + } + + return kTfLiteOk; +} + +TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_or_func = std::logical_or(); + return LogicalImpl(context, node, logical_or_func); +} + +TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_and_func = std::logical_and(); + return LogicalImpl(context, node, logical_and_func); +} + +} // namespace +} // namespace logical + +TfLiteRegistration* Register_LOGICAL_OR() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalOrEval}; + return &r; +} + +TfLiteRegistration* Register_LOGICAL_AND() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalAndEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..206cbde98fa48ec5f7c863bbced9dccc9cab5207 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LogicalOpModel : public SingleOpModel { + public: + LogicalOpModel(std::initializer_list input1_shape, + std::initializer_list input2_shape, BuiltinOperator op) { + input1_ = AddInput(TensorType_BOOL); + input2_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + ConfigureBuiltinOp(op); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_LOGICAL_OR: { + SetBuiltinOp(op, BuiltinOptions_LogicalOrOptions, + CreateLogicalOrOptions(builder_).Union()); + break; + } + case BuiltinOperator_LOGICAL_AND: { + SetBuiltinOp(op, BuiltinOptions_LogicalAndOptions, + CreateLogicalAndOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } +}; + +TEST(LogicalTest, LogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true, false, true, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, BroadcastLogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, LogicalAnd) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_AND); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true, false, true, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, BroadcastLogicalAnd) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_AND); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ff3dca932d4284321b299cfc79571c43fce7155 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/one_hot.cc @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace one_hot { + +constexpr int kIndicesTensor = 0; +constexpr int kDepthTensor = 1; +constexpr int kOnValueTensor = 2; +constexpr int kOffValueTensor = 3; +constexpr int kOutputTensor = 0; + +// Convenience utility for destructuring a node into the appropriate tensors and +// data for the op. Note that this destructuring is quite cheap, so we can avoid +// allocating op-specific, persistent data on the heap. +struct OneHotContext { + OneHotContext(TfLiteContext* context, TfLiteNode* node) { + indices = GetInput(context, node, kIndicesTensor); + depth = GetInput(context, node, kDepthTensor); + on_value = GetInput(context, node, kOnValueTensor); + off_value = GetInput(context, node, kOffValueTensor); + output = GetOutput(context, node, kOutputTensor); + + const auto* params = + reinterpret_cast(node->builtin_data); + const int indices_dims = indices->dims->size; + axis = (params->axis == -1) ? indices_dims : params->axis; + output_dims = indices_dims + 1; + dtype = on_value->type; + } + + const TfLiteTensor* indices; + const TfLiteTensor* depth; + const TfLiteTensor* on_value; + const TfLiteTensor* off_value; + TfLiteTensor* output; + int axis; + int output_dims; + TfLiteType dtype; +}; + +template +void OneHotComputeImpl(const OneHotContext& op_context) { + // prefix_dim_size == # of elements before the axis + // depth == # of elements per axis + // suffix_dim_size == # of elements after the axis + int prefix_dim_size = 1; + for (int i = 0; i < op_context.axis; ++i) { + prefix_dim_size *= op_context.indices->dims->data[i]; + } + const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size; + const int depth = *op_context.depth->data.i32; + + const T on_value = *GetTensorData(op_context.on_value); + const T off_value = *GetTensorData(op_context.off_value); + + // View the indices as a matrix of size: + // prefix_dim_size x suffix_dim_size + // View the output as a matrix of size: + // prefix_dim_size x depth x suffix_dim_size + // Then the output is: + // output(i, j, k) == (indices(i, k) == j) ? on : off + T* output = GetTensorData(op_context.output); + const TI* indices = GetTensorData(op_context.indices); + for (int i = 0; i < prefix_dim_size; ++i) { + for (int j = 0; j < depth; ++j) { + for (int k = 0; k < suffix_dim_size; ++k, ++output) { + *output = static_cast(indices[i * suffix_dim_size + k]) == j + ? on_value + : off_value; + } + } + } +} + +template +void OneHotCompute(const OneHotContext& op_context) { + if (op_context.indices->type == kTfLiteInt64) { + OneHotComputeImpl(op_context); + } else { + OneHotComputeImpl(op_context); + } +} + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const OneHotContext& op_context) { + TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0); + TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims); + for (int i = 0; i < op_context.output_dims; ++i) { + if (i < op_context.axis) { + output_size->data[i] = op_context.indices->dims->data[i]; + } else if (i == op_context.axis) { + output_size->data[i] = *op_context.depth->data.i32; + } else { + output_size->data[i] = op_context.indices->dims->data[i - 1]; + } + } + return context->ResizeTensor(context, op_context.output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OneHotContext op_context{context, node}; + switch (op_context.dtype) { + // TODO(b/111744875): Support uint8 and quantization. + case kTfLiteFloat32: + case kTfLiteInt16: + case kTfLiteInt32: + case kTfLiteInt64: + case kTfLiteBool: + op_context.output->type = op_context.dtype; + break; + default: + context->ReportError(context, "Unknown output data type: %d", + op_context.dtype); + return kTfLiteError; + } + + TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 || + op_context.indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, op_context.axis >= 0 && + op_context.axis < op_context.output_dims); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1); + TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype); + TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype); + + if (!IsConstantTensor(op_context.depth)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + + return ResizeOutputTensor(context, op_context); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OneHotContext op_context{context, node}; + + if (IsDynamicTensor(op_context.output)) { + ResizeOutputTensor(context, op_context); + } + + switch (op_context.output->type) { + case kTfLiteFloat32: + OneHotCompute(op_context); + break; + case kTfLiteInt32: + OneHotCompute(op_context); + break; + case kTfLiteInt64: + OneHotCompute(op_context); + break; + case kTfLiteBool: + OneHotCompute(op_context); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace one_hot + +TfLiteRegistration* Register_ONE_HOT() { + static TfLiteRegistration r = { + nullptr, + nullptr, + one_hot::Prepare, + one_hot::Eval, + }; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/one_hot_test.cc b/tensorflow/contrib/lite/kernels/one_hot_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b604ec7a7f86b333805d91a95cb5054f0257ae4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/one_hot_test.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class OneHotOpModel : public SingleOpModel { + public: + OneHotOpModel(std::initializer_list input_shape, int depth_value, + TensorType dtype, int axis = -1, T on_value = 1, + T off_value = 0, TensorType indices_type = TensorType_INT32) { + indices_ = AddInput(indices_type); + int depth = AddInput(TensorType_INT32); + int on = AddInput(dtype); + int off = AddInput(dtype); + output_ = AddOutput(dtype); + SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions, + CreateOneHotOptions(builder_, axis).Union()); + BuildInterpreter({input_shape}); + + PopulateTensor(depth, {depth_value}); + PopulateTensor(on, {on_value}); + PopulateTensor(off, {off_value}); + } + + template + void SetIndices(std::initializer_list data) { + PopulateTensor(indices_, data); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_; +}; + +TEST(OneHotOpTest, BasicFloat) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_FLOAT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f})); +} + +TEST(OneHotOpTest, BasicInt) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_INT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1})); +} + +TEST(OneHotOpTest, BasicBool) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_BOOL); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({true, false, false, false, true, false, false, + false, true})); +} + +TEST(OneHotOpTest, SmallDepth) { + const int depth = 1; + OneHotOpModel model({3}, depth, TensorType_INT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0})); +} + +TEST(OneHotOpTest, BigDepth) { + const int depth = 4; + OneHotOpModel model({2}, depth, TensorType_INT32); + model.SetIndices({0, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0})); +} + +TEST(OneHotOpTest, OnOffValues) { + const int depth = 3; + const int axis = -1; + const int on = 5; + const int off = 0; + OneHotOpModel model({4}, depth, TensorType_INT32, axis, on, off); + model.SetIndices({0, 2, -1, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0})); +} + +TEST(OneHotOpTest, ZeroAxis) { + const int depth = 3; + const int axis = 0; + const int on = 5; + const int off = 0; + OneHotOpModel model({4}, depth, TensorType_INT32, axis, on, off); + model.SetIndices({0, 2, -1, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0})); +} + +TEST(OneHotOpTest, MultiDimensionalIndices) { + const int depth = 3; + const int axis = -1; + const float on = 2; + const float off = 0; + OneHotOpModel model({2, 2}, depth, TensorType_FLOAT32, axis, on, off); + model.SetIndices({0, 2, 1, -1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0})); +} + +TEST(OneHotOpTest, Int64Indices) { + const int depth = 3; + const int axis = -1; + const int on = 1; + const int off = 0; + OneHotOpModel model({3}, depth, TensorType_INT32, axis, on, off, + TensorType_INT64); + std::initializer_list indices = {0, 1, 2}; + model.SetIndices(indices); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 0b70bed30899f95eaafdc9afd426144751f7db90..8d2c108116e1666f342392ada44854190a5b80ee 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -107,6 +107,37 @@ TfLiteRegistration* Register_SHAPE(); TfLiteRegistration* Register_POW(); TfLiteRegistration* Register_FAKE_QUANT(); TfLiteRegistration* Register_PACK(); +TfLiteRegistration* Register_ONE_HOT(); +TfLiteRegistration* Register_LOGICAL_OR(); +TfLiteRegistration* Register_LOGICAL_AND(); +TfLiteRegistration* Register_LOGICAL_NOT(); + +TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { + context->ReportError( + context, + "Regular TensorFlow ops are not supported by this interpreter. Make sure " + "you invoke the Eager delegate before inference."); + return kTfLiteError; +} + +const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op, + int version) const { + return MutableOpResolver::FindOp(op, version); +} + +const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op, + int version) const { + // Return the NULL Op for all ops whose name start with "Eager:", allowing + // the interpreter to delegate their execution. + if (string(op).find("Eager:") == 0) { + static TfLiteRegistration null_op{ + nullptr, nullptr, &UnsupportedTensorFlowOp, + nullptr, nullptr, BuiltinOperator_CUSTOM, + "Eager", 1}; + return &null_op; + } + return MutableOpResolver::FindOp(op, version); +} BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -197,6 +228,10 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_POW, Register_POW()); AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); AddBuiltin(BuiltinOperator_PACK, Register_PACK()); + AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); + AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); + AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); + AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index 940718d67e70b7206227b891ea529cb9e9619161..0296152d68d6836fd592a65eeea69a7d4ebbb6ef 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -26,6 +26,10 @@ namespace builtin { class BuiltinOpResolver : public MutableOpResolver { public: BuiltinOpResolver(); + + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; }; } // namespace builtin diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 99ecc16093372fcda3257eef6099af64770c1c27..49ba0571e2f214c0b2407240753fcec0661c71bf 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -37,10 +37,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node, // special -1 value, meaning it will be calculated automatically based on the // input. Here we calculate what that dimension should be so that the number // of output elements in the same as the number of input elements. - int num_input_elements = 1; - for (int i = 0; i < NumDimensions(input); ++i) { - num_input_elements *= SizeOfDimension(input, i); - } + int num_input_elements = NumElements(input); int num_output_elements = 1; int stretch_dim = -1; @@ -96,9 +93,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // The function is returned above this line if the shape tensor is usable. // Now fallback to the shape parameter in `TfLiteReshapeParams`. - - TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions); - for (int i = 0; i < params->num_dimensions; ++i) { + int num_dimensions = params->num_dimensions; + if (num_dimensions == 1 && params->shape[0] == 0) { + // Legacy tflite models use a shape parameter of [0] to indicate scalars, + // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during + // toco conversion. + num_dimensions = 0; + } + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { output_shape->data[i] = params->shape[i]; } return ResizeOutput(context, node, output_shape); diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc index aecbd0399f7454045e8189072f45b695b0525204..52d71350d3ba9a27bf9a8df7a194161c4fb7f87c 100644 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -22,18 +22,27 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::IsEmpty; class ReshapeOpModel : public SingleOpModel { public: ReshapeOpModel(std::initializer_list input_shape, - std::initializer_list new_shape) { + std::initializer_list new_shape, + bool use_shape_input_tensor = false) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); + int shape_input_tensor = + use_shape_input_tensor ? AddInput(TensorType_INT32) : -1; SetBuiltinOp( BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) .Union()); - BuildInterpreter({input_shape}); + if (use_shape_input_tensor) { + BuildInterpreter({input_shape, GetShape(shape_input_tensor)}); + PopulateTensor(shape_input_tensor, new_shape); + } else { + BuildInterpreter({input_shape}); + } } void SetInput(std::initializer_list data) { @@ -71,6 +80,14 @@ TEST(ReshapeOpTest, SimpleTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); } +TEST(ReshapeOpTest, ShapeTensorInput) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}, /*use_shape_input_tensor=*/true); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + TEST(ReshapeOpTest, WithStretchDimension) { ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); @@ -79,6 +96,22 @@ TEST(ReshapeOpTest, WithStretchDimension) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); } +TEST(ReshapeOpTest, ScalarOutput) { + ReshapeOpModel m({1}, {}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + +TEST(ReshapeOpTest, LegacyScalarOutput) { + ReshapeOpModel m({1}, {0}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 10caffea03ebcec7862df1627541ac3d076b04e4..f4289105f7931ae572f219a61b5479287aff926a 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -247,7 +247,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 3, 6, // 9, 12, // 4, 10, // - 10, 16 // + 12, 16 // }); m.SetSize({3, 3}); m.Invoke(); @@ -256,8 +256,8 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 7, 9, 10, // 9, 11, 12, // 4, 8, 10, // - 8, 12, 14, // - 10, 13, 16, // + 9, 12, 14, // + 12, 14, 16, // }))); ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); @@ -265,7 +265,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 3, 6, // 9, 12, // 4, 10, // - 10, 16 // + 12, 16 // }); const_m.Invoke(); EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ @@ -273,35 +273,35 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 7, 9, 10, // 9, 11, 12, // 4, 8, 10, // - 8, 12, 14, // - 10, 13, 16, // + 9, 12, 14, // + 12, 14, 16, // }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); m.SetInput({ - 3, 4, 6, 10, // - 9, 10, 12, 16, // + 3, 4, 6, 10, // + 10, 12, 14, 16, // }); m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 13, 12, 16, // + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 14, // + 10, 12, 12, 14, 14, 16, // }))); ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); const_m.SetInput({ - 3, 4, 6, 10, // - 9, 10, 12, 16, // + 3, 4, 6, 10, // + 10, 12, 14, 16, // }); const_m.Invoke(); EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 13, 12, 16, // + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 14, // + 10, 12, 12, 14, 14, 16, // }))); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index c9269599e58f095ded4788e2ab064583ae0a708c..03079f1c3b4110da9193f91ed22940594152b10f 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -113,7 +113,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); } -#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \ type::SpaceToBatchND(GetTensorData(op_context.input), \ GetTensorDims(op_context.input), \ GetTensorData(op_context.block_shape), \ @@ -121,34 +121,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorData(op_context.paddings), \ GetTensorDims(op_context.paddings), \ GetTensorData(op_context.output), \ - GetTensorDims(op_context.output)) + GetTensorDims(op_context.output), pad_value) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0); } break; case kTfLiteUInt8: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t, + op_context.output->params.zero_point); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t, + op_context.output->params.zero_point); } break; case kTfLiteInt32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0); } break; case kTfLiteInt64: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0); } break; default: diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc index 92a4a037d5873e608ee7bdbdfc5eaa5e9b62bc8c..5756573629a51917e39a312117a1fcd29c150dc0 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc @@ -23,6 +23,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::Matcher; class SpaceToBatchNDOpModel : public SingleOpModel { public: @@ -30,6 +31,10 @@ class SpaceToBatchNDOpModel : public SingleOpModel { PopulateTensor(input_, data); } + void SetQuantizedInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + void SetBlockShape(std::initializer_list data) { PopulateTensor(block_shape_, data); } @@ -41,6 +46,11 @@ class SpaceToBatchNDOpModel : public SingleOpModel { std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + protected: int input_; int block_shape_; @@ -56,18 +66,19 @@ class SpaceToBatchNDOpModel : public SingleOpModel { // m.Invoke(); class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { public: - SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + SpaceToBatchNDOpConstModel(const TensorData& input, std::initializer_list block_shape, - std::initializer_list paddings) { - input_ = AddInput(TensorType_FLOAT32); + std::initializer_list paddings, + const TensorData& output) { + input_ = AddInput(input); block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions, CreateSpaceToBatchNDOptions(builder_).Union()); - BuildInterpreter({input_shape}); + BuildInterpreter({input.shape}); } }; @@ -81,26 +92,30 @@ class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { // m.Invoke(); class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel { public: - SpaceToBatchNDOpDynamicModel(std::initializer_list input_shape) { - input_ = AddInput(TensorType_FLOAT32); + SpaceToBatchNDOpDynamicModel(const TensorData& input, + const TensorData& output) { + input_ = AddInput(input); block_shape_ = AddInput(TensorType_INT32); paddings_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions, CreateSpaceToBatchNDOptions(builder_).Union()); - BuildInterpreter({input_shape, {2}, {2, 2}}); + BuildInterpreter({input.shape, {2}, {2, 2}}); } }; TEST(SpaceToBatchNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}), - "Cannot allocate tensors"); + EXPECT_DEATH( + SpaceToBatchNDOpConstModel({TensorType_FLOAT32, {1, 3, 3, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}), + "Cannot allocate tensors"); } TEST(SpaceToBatchNDOpTest, SimpleConstTest) { - SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); @@ -109,7 +124,8 @@ TEST(SpaceToBatchNDOpTest, SimpleConstTest) { } TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetPaddings({0, 0, 0, 0}); @@ -120,7 +136,8 @@ TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { } TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { - SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); @@ -129,7 +146,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { } TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { - SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetPaddings({0, 0, 0, 0}); @@ -140,7 +158,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { } TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { - SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, {3, 2}, + {1, 0, 2, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); @@ -151,7 +170,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { } TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 0, 2, 0}); @@ -164,7 +184,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { } TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { - SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, {3, 2}, + {1, 1, 2, 4}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); @@ -176,7 +197,8 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { } TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 1, 2, 4}); @@ -189,6 +211,88 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { })); } +class QuantizedSpaceToBatchNDOpTest : public ::testing::Test { + protected: + std::vector> DequantizedArrayNear( + const std::vector& values, const float min, const float max) { + const float quantization_tolerance = (max - min) / 255.0; + return ArrayFloatNear(values, quantization_tolerance); + } +}; + +TEST_F(QuantizedSpaceToBatchNDOpTest, ZeroNotInQuantizationRange) { + // The test_util and actual quantization code currently ensure that the range + // must include zero, but if that ever changes, this test will catch it. + EXPECT_DEATH(SpaceToBatchNDOpConstModel m( + {TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, {TensorType_UINT8, {}, 1.0, 2.0}), + ".*Check failed: f_min <= 0.*"); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0}, + {3, 2}, {1, 0, 2, 0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 0, 2, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0}, + {3, 2}, {1, 1, 2, 4}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + { + 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, + 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, + 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0, + }, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + { + 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, + 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, + 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0, + }, + -1.0, 1.0))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc index 7be5e66c166cd752fc325f25d38e6522948e0f06..fec2a6f0d97ae48e0c49d82c726278a46d96a7fc 100644 --- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -187,7 +187,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return ResizeOutputShape(context, output_shape, output); } -template +template TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); const TfLiteTensor* output_shape = @@ -204,10 +204,10 @@ TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { const int num_indices = SizeOfDimension(indices, 0); const bool value_is_scalar = NumDimensions(values) == 0; - std::vector> indices_vector; + std::vector> indices_vector; indices_vector.reserve(num_indices); - TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, - &indices_vector)); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); reference_ops::SparseToDense(indices_vector, GetTensorData(values), *GetTensorData(default_value), GetTensorData(output), GetTensorDims(output), diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc index af77f074742eb3fef10a74616ff679255911fbb2..5181a8f89a376302bad02913e3c7c1d094821da8 100644 --- a/tensorflow/contrib/lite/kernels/tile.cc +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -87,8 +87,9 @@ std::pair TileOneDimension(const TfLiteIntArray& in_dimensions, if (dimension == in_dimensions.size - 1) { CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], out_data); - return std::make_pair(dimension_size, - dimension_size * multipliers[dimension]); + return std::make_pair( + dimension_size, + dimension_size * static_cast(multipliers[dimension])); } int total_stride_size = 0, total_tiled_stride_size = 0; const T* copy_from_data = in_data; diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa9a3cd1d839b07149bb80c3b7714b32b5eda235 --- /dev/null +++ b/tensorflow/contrib/lite/mmap_allocation.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/error_reporter.h" + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { + mmap_fd_ = open(filename, O_RDONLY); + if (mmap_fd_ == -1) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + fstat(mmap_fd_, &sb); + buffer_size_bytes_ = sb.st_size; + mmapped_buffer_ = + mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); + if (mmapped_buffer_ == MAP_FAILED) { + error_reporter_->Report("Mmap of '%s' failed.", filename); + return; + } +} + +MMAPAllocation::~MMAPAllocation() { + if (valid()) { + munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); + } + if (mmap_fd_ != -1) close(mmap_fd_); +} + +const void* MMAPAllocation::base() const { return mmapped_buffer_; } + +size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } + +bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } + +bool MMAPAllocation::IsSupported() { return true; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/mmap_allocation_disabled.cc b/tensorflow/contrib/lite/mmap_allocation_disabled.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3d4cf1a257d43ebd56cc9b8831de0bb1994d40c --- /dev/null +++ b/tensorflow/contrib/lite/mmap_allocation_disabled.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/allocation.h" + +#include + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(nullptr) { + // The disabled variant should never be created. + assert(false); +} + +MMAPAllocation::~MMAPAllocation() {} + +const void* MMAPAllocation::base() const { return nullptr; } + +size_t MMAPAllocation::bytes() const { return 0; } + +bool MMAPAllocation::valid() const { return false; } + +bool MMAPAllocation::IsSupported() { return false; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index c6869feb16040bd87965eeda8521b2151f445da5..9edf5ba38f4c6506524074bc0a3ebe7e068c7ee3 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -24,7 +23,9 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/model.h" +#ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" +#endif #include "tensorflow/contrib/lite/version.h" namespace tflite { @@ -73,6 +74,7 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, return kTfLiteOk; } +#ifndef TFLITE_MCU // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. std::unique_ptr GetAllocationFromFile(const char* filename, @@ -80,8 +82,8 @@ std::unique_ptr GetAllocationFromFile(const char* filename, ErrorReporter* error_reporter, bool use_nnapi) { std::unique_ptr allocation; - if (mmap_file) { - if (use_nnapi && NNAPIExists()) + if (mmap_file && MMAPAllocation::IsSupported()) { + if (use_nnapi && NNAPIDelegate::IsSupported()) allocation.reset(new NNAPIAllocation(filename, error_reporter)); else allocation.reset(new MMAPAllocation(filename, error_reporter)); @@ -120,6 +122,7 @@ std::unique_ptr FlatBufferModel::VerifyAndBuildFromFile( if (!model->initialized()) model.reset(); return model; } +#endif std::unique_ptr FlatBufferModel::BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { @@ -730,6 +733,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = static_cast(params); break; } + case BuiltinOperator_ONE_HOT: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { + params->axis = schema_params->axis(); + } + *builtin_data = static_cast(params); + break; + } // Below are the ops with no builtin_data strcture. case BuiltinOperator_BATCH_TO_SPACE_ND: @@ -773,6 +784,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_TRANSPOSE: case BuiltinOperator_POW: case BuiltinOperator_LOGICAL_OR: + case BuiltinOperator_LOGICAL_AND: + case BuiltinOperator_LOGICAL_NOT: break; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index edfdec9315aca449002ab98a34bf0a7dbb4ad4e6..df4f60d4ad4eb71f48eb3ad364f95f93b84f3d75 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -241,14 +241,6 @@ TEST(BasicFlatBufferModel, TestWithNullVerifier) { "tensorflow/contrib/lite/testdata/test_model.bin", nullptr)); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - calls++; - return 0; - } - int calls = 0; -}; - // This makes sure the ErrorReporter is marshalled from FlatBufferModel to // the Interpreter. TEST(BasicFlatBufferModel, TestCustomErrorReporter) { @@ -262,7 +254,7 @@ TEST(BasicFlatBufferModel, TestCustomErrorReporter) { TrivialResolver resolver; InterpreterBuilder(*model, resolver)(&interpreter); ASSERT_NE(interpreter->Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.num_calls(), 1); } // This makes sure the ErrorReporter is marshalled from FlatBufferModel to diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h index 90260c8d620b0e756f72089d3f4d8d9f92d44fbe..3151192d9277b6df513a76afb08af30d0379b7b1 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.h +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -65,9 +65,9 @@ struct SmartReplyConfig { float backoff_confidence; // Backoff responses are used when predicted responses cannot fulfill the // list. - const std::vector& backoff_responses; + std::vector backoff_responses; - SmartReplyConfig(std::vector backoff_responses) + SmartReplyConfig(const std::vector& backoff_responses) : num_response(kDefaultNumResponse), backoff_confidence(kDefaultBackoffConfidence), backoff_responses(backoff_responses) {} diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 551e8ed3201506181452fc15ad6cffebb8aacfb1..13325a8c7c62142fcbc8ca37af1216e8943b905b 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -568,9 +568,17 @@ TfLiteStatus AddOpsAndParams( "NNAPI does not support L2Normalization with fused activations"); } break; + case tflite::BuiltinOperator_HASHTABLE_LOOKUP: + if (interpreter->tensor(node.outputs->data[0])->type != + kTfLiteFloat32) { + logError("NNAPI only support HASHTABLE_LOOKUP with float32 output", + builtin); + return kTfLiteError; + } + nn_op_type = ANEURALNETWORKS_HASHTABLE_LOOKUP; + break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: - case tflite::BuiltinOperator_HASHTABLE_LOOKUP: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: @@ -623,6 +631,9 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_FAKE_QUANT: case tflite::BuiltinOperator_PACK: case tflite::BuiltinOperator_LOGICAL_OR: + case tflite::BuiltinOperator_ONE_HOT: + case tflite::BuiltinOperator_LOGICAL_AND: + case tflite::BuiltinOperator_LOGICAL_NOT: logError("Op code %d is currently not delegated to NNAPI", builtin); return kTfLiteError; break; @@ -788,4 +799,6 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { return kTfLiteOk; } +bool NNAPIDelegate::IsSupported() { return NNAPIExists(); } + } // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index 8dc7d38a303f51b7ccefefd8c9d2990b443e6827..2bdb2cc5c8211a48ea07e7ec45f9eebc0a3f7c10 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -19,9 +19,10 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" -class ANeuralNetworsModel; +class ANeuralNetworksModel; +class ANeuralNetworksMemory; +class ANeuralNetworksCompilation; namespace tflite { @@ -54,6 +55,9 @@ class NNAPIDelegate { // Run TfLiteStatus Invoke(Interpreter* interpreter); + // Whether the current platform supports NNAPI delegation. + static bool IsSupported(); + private: // The NN API model handle ANeuralNetworksModel* nn_model_ = nullptr; diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc new file mode 100644 index 0000000000000000000000000000000000000000..efde72b1a76a86728f4cccd8782ca0e993dd0338 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +#include + +namespace tflite { + +NNAPIAllocation::NNAPIAllocation(const char* filename, + ErrorReporter* error_reporter) + : MMAPAllocation(filename, error_reporter) { + // The disabled variant should never be created. + assert(false); +} + +NNAPIAllocation::~NNAPIAllocation() {} + +NNAPIDelegate::~NNAPIDelegate() {} + +TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { + return kTfLiteError; +} + +TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { + return kTfLiteError; +} + +bool NNAPIDelegate::IsSupported() { return false; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc index 446660bb747cd6e3b694669b64ac1d95cf415fbe..875ddb02bcfc30f4c2ef543fe1c15bec467e5410 100644 --- a/tensorflow/contrib/lite/profiling/time.cc +++ b/tensorflow/contrib/lite/profiling/time.cc @@ -14,16 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/profiling/time.h" +#if defined(_MSC_VER) +#include // NOLINT(build/c++11) +#else #include +#endif namespace tflite { namespace profiling { namespace time { + +#if defined(_MSC_VER) + +uint64_t NowMicros() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +#else + uint64_t NowMicros() { struct timeval tv; gettimeofday(&tv, nullptr); return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; } + +#endif // defined(_MSC_VER) + } // namespace time } // namespace profiling } // namespace tflite diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index f97919363bdc4900ae10c7a2b1a6e74e3d5728cc..9ab05f3068494a573ffa5b46f84be66a12d54e46 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -108,7 +108,9 @@ std::unique_ptr CreateInterpreter( ImportNumpy(); std::unique_ptr interpreter; - tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + return nullptr; + } return interpreter; } @@ -182,13 +184,37 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { } // namespace +InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::string* error_msg) { + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + + auto resolver = absl::make_unique(); + auto interpreter = CreateInterpreter(model.get(), *resolver); + if (!interpreter) { + *error_msg = error_reporter->message(); + return nullptr; + } + + InterpreterWrapper* wrapper = + new InterpreterWrapper(std::move(model), std::move(error_reporter), + std::move(resolver), std::move(interpreter)); + return wrapper; +} + InterpreterWrapper::InterpreterWrapper( std::unique_ptr model, - std::unique_ptr error_reporter) + std::unique_ptr error_reporter, + std::unique_ptr resolver, + std::unique_ptr interpreter) : model_(std::move(model)), error_reporter_(std::move(error_reporter)), - resolver_(absl::make_unique()), - interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} + resolver_(std::move(resolver)), + interpreter_(std::move(interpreter)) {} InterpreterWrapper::~InterpreterWrapper() {} @@ -421,11 +447,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( std::unique_ptr error_reporter(new PythonErrorReporter); std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); - if (!model) { - *error_msg = error_reporter->message(); - return nullptr; - } - return new InterpreterWrapper(std::move(model), std::move(error_reporter)); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( @@ -439,11 +462,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( std::unique_ptr model = tflite::FlatBufferModel::BuildFromBuffer(buf, length, error_reporter.get()); - if (!model) { - *error_msg = error_reporter->message(); - return nullptr; - } - return new InterpreterWrapper(std::move(model), std::move(error_reporter)); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 556ec7117a179ad9fd880b295c4969bab0f4a7a7..641dd93db5b9df292e03e9704a218299f48b14fb 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -15,12 +15,15 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ #define TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ -// Place `` before to avoid build failures in macOS. -#include #include #include #include +// Place `` before to avoid build failures in macOS. +#include + +// The empty line above is on purpose as otherwise clang-format will +// automatically move before . #include // We forward declare TFLite classes here to avoid exposing them to SWIG. @@ -69,14 +72,28 @@ class InterpreterWrapper { PyObject* tensor(PyObject* base_object, int i); private: - InterpreterWrapper(std::unique_ptr model, - std::unique_ptr error_reporter); + // Helper function to construct an `InterpreterWrapper` object. + // It only returns InterpreterWrapper if it can construct an `Interpreter`. + // Otherwise it returns `nullptr`. + static InterpreterWrapper* CreateInterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::string* error_msg); + + InterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::unique_ptr resolver, + std::unique_ptr interpreter); // InterpreterWrapper is not copyable or assignable. We avoid the use of // InterpreterWrapper() = delete here for SWIG compatibility. InterpreterWrapper(); InterpreterWrapper(const InterpreterWrapper& rhs); + // The public functions which creates `InterpreterWrapper` should ensure all + // these member variables are initialized successfully. Otherwise it should + // report the error and return `nullptr`. const std::unique_ptr model_; const std::unique_ptr error_reporter_; const std::unique_ptr resolver_; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index a285bf991996fe8398ad77936004716bbacd077a..14f88b4c009e4f7cd913c2a27799ab418562fb1f 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -166,6 +166,9 @@ enum BuiltinOperator : byte { REDUCE_MAX = 82, PACK = 83, LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, } // Options for the builtin operators. @@ -230,6 +233,9 @@ union BuiltinOptions { FakeQuantOptions, PackOptions, LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, } enum Padding : byte { SAME, VALID } @@ -549,6 +555,16 @@ table PackOptions { table LogicalOrOptions { } +table OneHotOptions { + axis:int; +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 8c1d6d6a36a5081aa7ebb94c50ee4ec76a079b7f..3efa153e2cfd98dcac9352ff0ef4d8eb9bb6b66a 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -211,6 +211,15 @@ struct PackOptionsT; struct LogicalOrOptions; struct LogicalOrOptionsT; +struct OneHotOptions; +struct OneHotOptionsT; + +struct LogicalAndOptions; +struct LogicalAndOptionsT; + +struct LogicalNotOptions; +struct LogicalNotOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -361,11 +370,14 @@ enum BuiltinOperator { BuiltinOperator_REDUCE_MAX = 82, BuiltinOperator_PACK = 83, BuiltinOperator_LOGICAL_OR = 84, + BuiltinOperator_ONE_HOT = 85, + BuiltinOperator_LOGICAL_AND = 86, + BuiltinOperator_LOGICAL_NOT = 87, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_LOGICAL_OR + BuiltinOperator_MAX = BuiltinOperator_LOGICAL_NOT }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -450,7 +462,10 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] { BuiltinOperator_REDUCE_PROD, BuiltinOperator_REDUCE_MAX, BuiltinOperator_PACK, - BuiltinOperator_LOGICAL_OR + BuiltinOperator_LOGICAL_OR, + BuiltinOperator_ONE_HOT, + BuiltinOperator_LOGICAL_AND, + BuiltinOperator_LOGICAL_NOT }; return values; } @@ -542,6 +557,9 @@ inline const char **EnumNamesBuiltinOperator() { "REDUCE_MAX", "PACK", "LOGICAL_OR", + "ONE_HOT", + "LOGICAL_AND", + "LOGICAL_NOT", nullptr }; return names; @@ -614,11 +632,14 @@ enum BuiltinOptions { BuiltinOptions_FakeQuantOptions = 58, BuiltinOptions_PackOptions = 59, BuiltinOptions_LogicalOrOptions = 60, + BuiltinOptions_OneHotOptions = 61, + BuiltinOptions_LogicalAndOptions = 62, + BuiltinOptions_LogicalNotOptions = 63, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_LogicalOrOptions + BuiltinOptions_MAX = BuiltinOptions_LogicalNotOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -680,7 +701,10 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] { BuiltinOptions_ArgMinOptions, BuiltinOptions_FakeQuantOptions, BuiltinOptions_PackOptions, - BuiltinOptions_LogicalOrOptions + BuiltinOptions_LogicalOrOptions, + BuiltinOptions_OneHotOptions, + BuiltinOptions_LogicalAndOptions, + BuiltinOptions_LogicalNotOptions }; return values; } @@ -748,6 +772,9 @@ inline const char **EnumNamesBuiltinOptions() { "FakeQuantOptions", "PackOptions", "LogicalOrOptions", + "OneHotOptions", + "LogicalAndOptions", + "LogicalNotOptions", nullptr }; return names; @@ -1002,6 +1029,18 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1513,6 +1552,30 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_LogicalOrOptions ? reinterpret_cast(value) : nullptr; } + OneHotOptionsT *AsOneHotOptions() { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + const OneHotOptionsT *AsOneHotOptions() const { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalAndOptionsT *AsLogicalAndOptions() { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalAndOptionsT *AsLogicalAndOptions() const { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalNotOptionsT *AsLogicalNotOptions() { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalNotOptionsT *AsLogicalNotOptions() const { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -5452,6 +5515,140 @@ inline flatbuffers::Offset CreateLogicalOrOptions( flatbuffers::Offset CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct OneHotOptionsT : public flatbuffers::NativeTable { + typedef OneHotOptions TableType; + int32_t axis; + OneHotOptionsT() + : axis(0) { + } +}; + +struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OneHotOptionsT NativeTableType; + enum { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); + } + OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OneHotOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(OneHotOptions::VT_AXIS, axis, 0); + } + explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOneHotOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { + OneHotOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +flatbuffers::Offset CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalAndOptionsT : public flatbuffers::NativeTable { + typedef LogicalAndOptions TableType; + LogicalAndOptionsT() { + } +}; + +struct LogicalAndOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalAndOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalAndOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalAndOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalAndOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalAndOptionsBuilder &operator=(const LogicalAndOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalAndOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalAndOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalNotOptionsT : public flatbuffers::NativeTable { + typedef LogicalNotOptions TableType; + LogicalNotOptionsT() { + } +}; + +struct LogicalNotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalNotOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalNotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalNotOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalNotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalNotOptionsBuilder &operator=(const LogicalNotOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalNotOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalNotOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -5765,6 +5962,15 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const { return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast(builtin_options()) : nullptr; } + const OneHotOptions *builtin_options_as_OneHotOptions() const { + return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalAndOptions *builtin_options_as_LogicalAndOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalAndOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalNotOptions *builtin_options_as_LogicalNotOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalNotOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -6036,6 +6242,18 @@ template<> inline const LogicalOrOptions *Operator::builtin_options_as inline const OneHotOptions *Operator::builtin_options_as() const { + return builtin_options_as_OneHotOptions(); +} + +template<> inline const LogicalAndOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalAndOptions(); +} + +template<> inline const LogicalNotOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalNotOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -8151,6 +8369,78 @@ inline flatbuffers::Offset CreateLogicalOrOptions(flatbuffers: _fbb); } +inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OneHotOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; }; +} + +inline flatbuffers::Offset OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOneHotOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + return tflite::CreateOneHotOptions( + _fbb, + _axis); +} + +inline LogicalAndOptionsT *LogicalAndOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalAndOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalAndOptions::UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalAndOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalAndOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalAndOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalAndOptions( + _fbb); +} + +inline LogicalNotOptionsT *LogicalNotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalNotOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalNotOptions::UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalNotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalNotOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalNotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalNotOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -8580,6 +8870,18 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -8838,6 +9140,18 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -9084,6 +9398,18 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + return CreateOneHotOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalAndOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -9330,6 +9656,18 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new LogicalOrOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_OneHotOptions: { + value = new OneHotOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalAndOptions: { + value = new LogicalAndOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalNotOptions: { + value = new LogicalNotOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -9637,6 +9975,21 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 24593d2a673d5423be693800c03984e690a7e129..cd0f1f7c17a50f6ce61fa2033e5d13580399f5cf 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/simple_memory_arena.h" +#include #include #include #include diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 4c37bcb3c98cf3f1e0c4b5ffa3843e9f461b7157..a788d41ba7b370cd0e84c343202f1dca090180f3 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -112,7 +112,6 @@ cc_library( cc_test( name = "message_test", srcs = ["message_test.cc"], - tags = ["no_oss"], deps = [ ":message", "@com_google_googletest//:gtest_main", @@ -132,7 +131,6 @@ cc_test( name = "split_test", size = "small", srcs = ["split_test.cc"], - tags = ["no_oss"], deps = [ ":split", "@com_google_googletest//:gtest_main", @@ -142,13 +140,13 @@ cc_test( cc_library( name = "join", hdrs = ["join.h"], + deps = ["//tensorflow/contrib/lite:string"], ) cc_test( name = "join_test", size = "small", srcs = ["join_test.cc"], - tags = ["no_oss"], deps = [ ":join", "@com_google_googletest//:gtest_main", @@ -174,7 +172,6 @@ cc_test( srcs = ["tflite_driver_test.cc"], data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], tags = [ - "no_oss", "tflite_not_portable_android", "tflite_not_portable_ios", ], @@ -196,7 +193,6 @@ cc_library( cc_test( name = "tokenize_test", srcs = ["tokenize_test.cc"], - tags = ["no_oss"], deps = [ ":tokenize", "@com_google_googletest//:gtest_main", @@ -214,12 +210,15 @@ cc_library( cc_library( name = "util", hdrs = ["util.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string", + ], ) cc_test( name = "test_runner_test", srcs = ["test_runner_test.cc"], - tags = ["no_oss"], deps = [ ":test_runner", "@com_google_googletest//:gtest_main", @@ -275,6 +274,7 @@ cc_library( ":join", ":split", ":tf_driver", + "//tensorflow/contrib/lite:string", "//tensorflow/core:framework", ], ) @@ -341,7 +341,7 @@ tf_cc_test( ], tags = [ "no_cuda_on_cpu_tap", - "no_oss", + "no_oss", # needs test data "tflite_not_portable", ], deps = [ diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 41ece94237a0ff05c4c441d657ca0df813b69ea9..52ef0d5b86524d605b2f5d6dbae98d4c343ad6a0 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -90,8 +90,6 @@ TEST_INPUT_DEPTH = 3 # matching the expression will be considered due to the corresponding bug. KNOWN_BUGS = { # TOCO doesn't support scalars as input. - r"relu.*input_shape=\[\]": "67587484", - r"sigmoid.*input_shape=\[\]": "67645668", # Concat doesn't work with a single input tensor r"concat.*num_tensors=1": "67378344", # Transposition in MatMul is not fully supported. @@ -228,7 +226,9 @@ _TF_TYPE_INFO = { tf.float16: (np.float16, "FLOAT"), tf.int32: (np.int32, "INT32"), tf.uint8: (np.uint8, "QUANTIZED_UINT8"), + tf.int16: (np.int16, "QUANTIZED_INT16"), tf.int64: (np.int64, "INT64"), + tf.bool: (np.bool, "BOOL"), } @@ -240,9 +240,12 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value-min_value)*np.random.random_sample(shape)+min_value - elif dtype in (tf.int32, tf.uint8, tf.int64): + elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value+1, shape) - return value.astype(dtype) + elif dtype == tf.bool: + value = np.random.choice([True, False], size=shape) + return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype( + dtype) def create_scalar_data(dtype, min_value=-100, max_value=100): @@ -253,7 +256,7 @@ def create_scalar_data(dtype, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value - min_value) * np.random.random() + min_value - elif dtype in (tf.int32, tf.uint8, tf.int64): + elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value + 1) return np.array(value, dtype=dtype) @@ -479,7 +482,7 @@ def make_zip_of_tests(zip_path, else report_lib.FAILED) report["toco_log"] = toco_log - if FLAGS.save_graphdefs: + if True or FLAGS.save_graphdefs: archive.writestr(label + ".pbtxt", text_format.MessageToString(graph_def), zipfile.ZIP_DEFLATED) @@ -681,12 +684,20 @@ def make_relu6_tests(zip_path): def make_prelu_tests(zip_path): """Make a set of tests to do PReLU.""" - test_parameters = [{ - # The canonical case for image processing is having a 4D `input` (NHWC) - # and `shared_axes`=[1, 2], so the alpha parameter is per channel. - "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]], - "shared_axes": [[1, 2], [1]], - }] + test_parameters = [ + { + # The canonical case for image processing is having a 4D `input` + # (NHWC)and `shared_axes`=[1, 2], so the alpha parameter is per + # channel. + "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]], + "shared_axes": [[1, 2], [1]], + }, + { + # 2D-3D example. Share the 2nd axis. + "input_shape": [[20, 20], [20, 20, 20]], + "shared_axes": [[1]], + } + ] def build_graph(parameters): """Build the graph for the test case.""" @@ -734,21 +745,22 @@ def make_constant_tests(zip_path): test_parameters = [{ "dtype": [tf.float32, tf.int32], - "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], + "input_shape": [[], [1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], }] def build_graph(parameters): - # Since Toco & Tflite can't have a single constant op in the entire graph, - # this test adds a zero tensor with a constant op tensor. - input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", - shape=parameters["input_shape"]) - out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 - return [input1], [out] + dummy_input = tf.placeholder( + dtype=parameters["dtype"], + name="input1", + shape=parameters["input_shape"]) + out = tf.constant( + create_tensor_data(parameters["dtype"], parameters["input_shape"])) + return [dummy_input], [out] def build_inputs(parameters, sess, inputs, outputs): - input1 = np.zeros(parameters["input_shape"], - dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) - return [input1], sess.run(outputs, feed_dict={inputs[0]: input1}) + dummy_input = np.zeros( + parameters["input_shape"], dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) + return [dummy_input], sess.run(outputs, feed_dict={inputs[0]: dummy_input}) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -809,11 +821,13 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_reduce_tests(reduce_op): +def make_reduce_tests(reduce_op, min_value=-10, max_value=10): """Make a set of tests to do reduce operation. Args: reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`. + min_value: min value for created tensor data. + max_value: max value for created tensor data. Returns: a function representing the true generator with `reduce_op_in` curried. @@ -876,10 +890,12 @@ def make_reduce_tests(reduce_op): def build_inputs(parameters, sess, inputs, outputs): values = [ - create_tensor_data(parameters["input_dtype"], - parameters["input_shape"], - min_value=-10, - max_value=10)] + create_tensor_data( + parameters["input_dtype"], + parameters["input_shape"], + min_value=min_value, + max_value=max_value) + ] if not parameters["const_axis"]: values.append(np.array(parameters["axis"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) @@ -901,7 +917,8 @@ def make_sum_tests(zip_path): def make_reduce_prod_tests(zip_path): """Make a set of tests to do prod.""" - return make_reduce_tests(tf.reduce_prod)(zip_path) + # set min max value to be -2, 2 to avoid overflow. + return make_reduce_tests(tf.reduce_prod, -2, 2)(zip_path) def make_reduce_max_tests(zip_path): @@ -1340,6 +1357,7 @@ def make_concat_tests(zip_path): "base_shape": [[1, 3, 4, 3], [3, 4]], "num_tensors": [1, 2, 3, 4, 5, 6], "axis": [0, 1, 2, 3, -3, -2, -1], + "type": [tf.float32, tf.uint8, tf.int32, tf.int64], }] def get_shape(parameters, delta): @@ -1355,7 +1373,8 @@ def make_concat_tests(zip_path): def build_graph(parameters): all_tensors = [] for n in range(0, parameters["num_tensors"]): - input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n), + input_tensor = tf.placeholder(dtype=parameters["type"], + name=("input%d" % n), shape=get_shape(parameters, n)) all_tensors.append(input_tensor) out = tf.concat(all_tensors, parameters["axis"]) @@ -1364,8 +1383,8 @@ def make_concat_tests(zip_path): def build_inputs(parameters, sess, inputs, outputs): all_values = [] for n in range(0, parameters["num_tensors"]): - input_values = create_tensor_data(np.float32, - get_shape(parameters, n)) + input_values = create_tensor_data( + parameters["type"], get_shape(parameters, n)) all_values.append(input_values) return all_values, sess.run( outputs, feed_dict=dict(zip(inputs, all_values))) @@ -1608,6 +1627,11 @@ def make_reshape_tests(zip_path): "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], "constant_shape": [True, False], + }, { + "dtype": [tf.float32], + "input_shape": [[1]], + "output_shape": [[]], + "constant_shape": [True, False], }] def build_graph(parameters): @@ -1649,7 +1673,7 @@ def make_shape_tests(zip_path): }] def build_graph(parameters): - """Build the topk op testing graph.""" + """Build the shape op testing graph.""" # Note that we intentionally leave out the shape from the input placeholder # to prevent the Shape operation from being optimized out during conversion. input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") @@ -1665,6 +1689,65 @@ def make_shape_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_one_hot_tests(zip_path): + """Make a set of tests to do one_hot.""" + + test_parameters = [{ + "indices_type": [tf.int32, tf.int64], + "indices_shape": [[3], [4, 4], [1, 5], [5, 1]], + "axis": [0, 1], + "dtype": [tf.int32, tf.int64, tf.float32], + "provide_optional_inputs": [True, False], + }] + + def build_graph(parameters): + indices = tf.placeholder( + dtype=parameters["indices_type"], + name="indices", + shape=parameters["indices_shape"]) + depth = tf.placeholder(dtype=tf.int32, name="depth", shape=()) + + if not parameters["provide_optional_inputs"]: + out = tf.one_hot(indices=indices, depth=depth) + return [indices, depth], [out] + + on_value = tf.placeholder( + dtype=parameters["dtype"], name="on_value", shape=()) + off_value = tf.placeholder( + dtype=parameters["dtype"], name="off_value", shape=()) + out = tf.one_hot( + indices=indices, + depth=depth, + on_value=on_value, + off_value=off_value, + axis=parameters["axis"], + dtype=parameters["dtype"]) + return [indices, depth, on_value, off_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = [ + create_tensor_data( + parameters["indices_type"], + shape=parameters["indices_shape"], + min_value=-1, + max_value=10), + create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10), + ] + + if parameters["provide_optional_inputs"]: + input_values.append( + create_tensor_data( + parameters["dtype"], shape=None, min_value=1, max_value=10)) + input_values.append( + create_tensor_data( + parameters["dtype"], shape=None, min_value=-1, max_value=0)) + + return input_values, sess.run( + outputs, feed_dict=dict(zip(inputs, input_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_resize_bilinear_tests(zip_path): """Make a set of tests to do resize_bilinear.""" @@ -2238,6 +2321,7 @@ def make_topk_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32], "input_shape": [[10], [5, 20]], + "input_k": [None, 1, 3], }] def build_graph(parameters): @@ -2246,15 +2330,23 @@ def make_topk_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) - k = tf.constant(3, name="k") + if parameters["input_k"] is not None: + k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[]) + else: + k = tf.constant(3, name="k") out = tf.nn.top_k(input_value, k) - return [input_value], [out[1]] + return [input_value, k], [out[1]] def build_inputs(parameters, sess, inputs, outputs): input_value = create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_value]))) + if parameters["input_k"] is not None: + k = np.array(parameters["input_k"], dtype=np.int32) + return [input_value, k], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, k]))) + else: + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -2918,6 +3010,57 @@ def make_pack_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def _make_logical_tests(op): + """Make a set of tests to do logical operations.""" + + def logical(zip_path): + """Generate examples.""" + test_parameters = [{ + "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the logical testing graph.""" + input_value1 = tf.placeholder( + dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1]) + out = op(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return logical + + +def make_logical_or_tests(zip_path): + """Make a set of tests to do logical_or.""" + return _make_logical_tests(tf.logical_or)(zip_path) + + +def make_logical_and_tests(zip_path): + """Make a set of tests to do logical_and.""" + return _make_logical_tests(tf.logical_and)(zip_path) + + +def make_logical_xor_tests(zip_path): + """Make a set of tests to do logical_xor. + + Test logical_not as well. + """ + return _make_logical_tests(tf.logical_xor)(zip_path) + + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc index c1092e4d25567f0374e3cd5a27bde32419d3db19..f29c188e6c2c55bdb13d257c70e23c2943abfa4a 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.cc +++ b/tensorflow/contrib/lite/testing/generate_testspec.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/contrib/lite/testing/generate_testspec.h" #include "tensorflow/contrib/lite/testing/join.h" #include "tensorflow/contrib/lite/testing/split.h" @@ -88,13 +90,13 @@ bool GenerateTestSpecFromTensorflowModel( TfDriver runner(input_layer, input_layer_type, input_layer_shape, output_layer); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } runner.LoadModel(tensorflow_model_path); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } @@ -118,14 +120,14 @@ bool GenerateTestSpecFromTensorflowModel( for (int j = 0; j < input_values.size(); j++) { runner.SetInput(j, input_values[j]); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } } runner.Invoke(); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } @@ -137,7 +139,7 @@ bool GenerateTestSpecFromTensorflowModel( for (int j = 0; j < output_layer.size(); j++) { stream << " output: \"" << runner.ReadOutput(j) << "\"\n"; if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } } diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h index bfaf5e7ec89bbdd85b68a7dc45d7686e143e5d3d..b3d0db31c01a8cb1b8f34ff6dbb00c77de29b131 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.h +++ b/tensorflow/contrib/lite/testing/generate_testspec.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/string.h" + namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 770092e12c537c170b34536a51a9d0fbf9e1bf6b..e475f256c01c95755aabb0550153ff3c225aeb3b 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -86,9 +86,6 @@ std::map kBrokenTests = { // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, - // PRelu only supports 4D input with (1, 1, channels) 3D alpha now. - {R"(^\/prelu.*shared_axes=\[1\])", "75975192"}, - // No support for axis!=0 in GatherV2. {R"(^\/gather.*axis=1)", "76910444"}, @@ -226,7 +223,8 @@ TEST_P(OpsTest, RunZipTests) { string message = test_driver.GetErrorMessage(); if (bug_number.empty()) { if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) { - EXPECT_EQ(message, string("Failed to invoke interpreter")) << message; + EXPECT_EQ(message, string("Failed to invoke NNAPI interpreter")) + << message; } else { EXPECT_TRUE(result) << message; } diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/contrib/lite/testing/join.h index 1edee01cf97da3c53be1895e667b005551ac2991..4be19ad7569c3333b6647b91adbc6e77ff088f10 100644 --- a/tensorflow/contrib/lite/testing/join.h +++ b/tensorflow/contrib/lite/testing/join.h @@ -17,7 +17,8 @@ limitations under the License. #include #include -#include + +#include "tensorflow/contrib/lite/string.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index 96ab6be54e528334f9e4a8cc259e44f99878fefb..fac7d01aab4b1e4c251213041eb4b823cd7d66aa 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -90,7 +90,7 @@ class TestRunner { // Invalidate the test runner, preventing it from executing any further. void Invalidate(const string& error_message) { - cerr << error_message << std::endl; + std::cerr << error_message << std::endl; error_message_ = error_message; } bool IsValid() const { return error_message_.empty(); } diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc index 3b27f6f3da92ce80c3830feb7c6af095e7c48e9c..ec435ca60d959a11a9392b6fbab99b0561f50942 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.cc +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -28,8 +28,8 @@ namespace { tensorflow::Tensor CreateTensor(const tensorflow::DataType type, const std::vector& dim) { - tensorflow::TensorShape shape{gtl::ArraySlice{ - reinterpret_cast(dim.data()), dim.size()}}; + tensorflow::TensorShape shape{tensorflow::gtl::ArraySlice{ + reinterpret_cast(dim.data()), dim.size()}}; return {type, shape}; } @@ -179,7 +179,7 @@ void TfDriver::Invoke() { auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, output_names_, {}, &output_tensors_); if (!status.ok()) { - Invalidate("Failed to invoke interpreter"); + Invalidate("Failed to run input data on graph"); } } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h index 7a57e8d3fba29cd106eb038992bb5ed12bb457ae..695c2a3de6c5d7c74a943134f0c97390710ef1e7 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#include + #include "tensorflow/contrib/lite/testing/split.h" #include "tensorflow/contrib/lite/testing/tflite_diff_util.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 6d20aec141c7c3a3e48af290edb169c6fd7254cf..8aa639157b8b68061f9ee8c3483959a79cb5794e 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -15,8 +15,39 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#include + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string.h" + namespace tflite { +// An ErrorReporter that collects error message in a string, in addition +// to printing to stderr. +class TestErrorReporter : public ErrorReporter { + public: + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + fprintf(stderr, "%s", buffer); + error_messages_ += buffer; + num_calls_++; + return size; + } + + void Reset() { + num_calls_ = 0; + error_messages_.clear(); + } + + int num_calls() const { return num_calls_; } + const string& error_messages() const { return error_messages_; } + + private: + int num_calls_ = 0; + string error_messages_; +}; + inline void LogToStderr() { #ifdef PLATFORM_GOOGLE FLAGS_logtostderr = true; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index c88079717ddc9bf39850762dffe711f0d2832d38..7243e584e9c26ce64bfa95e63842cda2087c655c 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -11,6 +11,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", + "tf_copts", ) tf_proto_library_cc( @@ -305,7 +306,7 @@ cc_library( "tensorflow_util.h", "toco_tooling.h", ], - copts = select({ + copts = tf_copts() + select({ "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], "//conditions:default": [], }), diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 6877fb237c0514a972589ac0301647104f5ed7ed..30525efd2391bb63afd7035b8134e5858add45f2 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -167,7 +167,7 @@ NodeProperties GetPropertiesForArray(const Model& model, node_properties.label += "]"; int buffer_size = 0; - if (IsValid(array.shape())) { + if (IsNonEmpty(array.shape())) { buffer_size = RequiredBufferSizeForShape(array.shape()); node_properties.log2_buffer_size = std::log2(static_cast(buffer_size)); diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index b79bb300f0495e84c18d25c2063f58be3fc4fbe7..02671f0408f55726df730dbe0fe9a4f936d22632 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -664,13 +664,25 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, void ConvertMulOperator(const Model& model, const MulOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); - add_op->set_op("Mul"); - add_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); + mul_op->set_op("Mul"); + mul_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *add_op->add_input() = src_op.inputs[0]; - *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type( + *mul_op->add_input() = src_op.inputs[0]; + *mul_op->add_input() = src_op.inputs[1]; + (*mul_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + +void ConvertDivOperator(const Model& model, const DivOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* div_op = tensorflow_graph->add_node(); + div_op->set_op("Div"); + div_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *div_op->add_input() = src_op.inputs[0]; + *div_op->add_input() = src_op.inputs[1]; + (*div_op->mutable_attr())["T"].set_type( GetTensorFlowDataType(model, src_op.outputs[0])); } @@ -1316,6 +1328,20 @@ void ConvertResizeBilinearOperator(const Model& model, (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners); } +void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node(); + onehot_op->set_op("OneHot"); + onehot_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (const auto& input : src_op.inputs) { + *onehot_op->add_input() = input; + } + (*onehot_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); + (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis); +} + namespace { // TODO(aselle): Remove when available in absl absl::string_view FindLongestCommonPrefix(absl::string_view a, @@ -1911,6 +1937,36 @@ void ConvertLogicalNotOperator(const Model& model, *logical_op->add_input() = src_op.inputs[0]; } +void ConvertLogicalOrOperator(const Model& model, + const LogicalOrOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node(); + logical_or_op->set_op(op_name); + logical_or_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_or_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*logical_or_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertCTCBeamSearchDecoderOperator( + const Model& model, const CTCBeamSearchDecoderOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op(op_name); + op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *op->add_input() = src_op.inputs[i]; + } + (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width); + (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths); + (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1946,6 +2002,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kMul) { ConvertMulOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kDiv) { + ConvertDivOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kRelu) { ConvertReluOperator(model, static_cast(src_op), tensorflow_graph); @@ -2158,6 +2217,17 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertLogicalNotOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kOneHot) { + ConvertOneHotOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalOr) { + ConvertLogicalOrOperator(model, + static_cast(src_op), + "LogicalOr", tensorflow_graph); + } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) { + ConvertCTCBeamSearchDecoderOperator( + model, static_cast(src_op), + "CTCBeamSearchDecoder", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index 75642bbc37be6b3140e5b79a463ca70b5786d772..c13fc0de7502a9edc80dc399354708a5b1b96b02 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -181,7 +181,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, // future without worrying. static constexpr int kMinDistanceBetweenBadValues = 16; if (distance < kMinDistanceBetweenBadValues) { - if (allow_nudging_weights()) { + if (allow_nudging_weights() || has_default_ranges_flag()) { buffer_data[i] = 1; changed = true; continue; @@ -200,6 +200,15 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, } if (changed) { + if (has_default_ranges_flag()) { + std::cerr + << "Since the specified values of --default_ranges_min and " + "--default_ranges_max result in values incompatible with TFLite's " + "fast int8 kernels, " + "--allow_nudging_weights_to_use_fast_gemm_kernel " + "has been enabled. This may affect the accuracy of the model." + << std::endl; + } AddMessageF("Tweaked weights values for %s", LogName(op)); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index b7634e28c6a1e509d2b68b9f514f86a26c233f5d..8d9a4c4700e12ac1a187038a0a5efc1b033d4e57 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -262,8 +262,12 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { bool allow_nudging_weights() const { return allow_nudging_weights_; } void set_allow_nudging_weights(bool val) { allow_nudging_weights_ = val; } + bool has_default_ranges_flag() const { return has_default_ranges_flag_; } + void set_has_default_ranges_flag(bool val) { has_default_ranges_flag_ = val; } + private: bool allow_nudging_weights_ = false; + bool has_default_ranges_flag_ = false; }; #undef DECLARE_GRAPH_TRANSFORMATION diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 2f1bb8f0ad6374243e5a094701eef54cd086151a..d26c3b2878b8499fcbabc5448de9ec045eb07879 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -371,12 +371,26 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { case OperatorType::kStridedSlice: case OperatorType::kSqueeze: case OperatorType::kReshape: + case OperatorType::kExpandDims: case OperatorType::kPad: case OperatorType::kGather: case OperatorType::kTranspose: case OperatorType::kMean: changed = HardcodeMinMaxFromFirstInput(model, op); break; + case OperatorType::kSum: + // reduce_sum is expected to change the output range. Hence + // a fake_quant op is necessary in the output to minimize error. However + // in special circumstances like when computing expected value using + // reduce_sum the input range and the output range matches. Hence the + // below code would act as a fallback. If a fake_quant node is observed in + // the output that takes precendence over the hard coding logic below. + changed = HardcodeMinMaxFromFirstInput(model, op); + if (changed) { + LOG(WARNING) << "Using the input range for output in reduce_sum op." + << "This could have an impact on your model accuracy."; + } + break; case OperatorType::kSelect: changed = HardcodeMinMaxForSelect(model, op); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 9848d55c8307d3704e1f3f419a158fc844ec69f3..c8310161cb33bcc7137e8b163ea6469698ed2fd7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -65,6 +65,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { case OperatorType::kAny: case OperatorType::kLogicalAnd: case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; @@ -141,7 +142,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 2); CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32); - model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type; + model->GetArray(op->outputs[0]).data_type = + model->GetArray(op->inputs[0]).data_type; model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32; break; } @@ -154,8 +156,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { return false; } for (int i = 0; i < op->outputs.size(); ++i) { - auto output = op->outputs[i]; - auto data_type = unsupported_op->output_data_types[i]; + const string& output = op->outputs[i]; + const ArrayDataType data_type = unsupported_op->output_data_types[i]; model->GetArray(output).data_type = data_type; } break; @@ -201,6 +203,30 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type); break; } + case OperatorType::kOneHot: { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + const ArrayDataType on_value_type = + model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type; + const ArrayDataType off_value_type = + model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT]) + .data_type; + CHECK(on_value_type == off_value_type); + model->GetArray(op->outputs[0]).data_type = on_value_type; + break; + } + case OperatorType::kCTCBeamSearchDecoder: { + CHECK_EQ(op->inputs.size(), 2); + // All outputs (sparse tensors) are int32s (although tf uses int64s) + // except the last one (log probabilities) is float. + const int output_size = op->outputs.size(); + for (int i = 0; i < output_size - 1; ++i) { + model->GetArray(op->outputs[i]).data_type = ArrayDataType::kInt32; + } + model->GetArray(op->outputs[output_size - 1]).data_type = + ArrayDataType::kFloat; + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 62ed5c46e99c561764e853c0419f2046c8086dba..91e290439ae4bfd491c8201b02b161fe2caf2f8d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1082,27 +1082,23 @@ void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { } // Yield until input dims have been resolved. - if (!input_values.has_shape()) { + if (!input_values.has_shape() || !input_k.has_shape()) { return; } - const auto& input_values_shape = input_values.shape(); - auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); - auto output_values_dims = output_values.mutable_shape()->mutable_dims(); - for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { - output_indexes_dims->push_back(input_values_shape.dims(dim)); - output_values_dims->push_back(input_values_shape.dims(dim)); - } // If the value is initialized, we can specify the last dimension, otherwise // unknown. if (input_k.buffer) { + const auto& input_values_shape = input_values.shape(); + auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); + auto output_values_dims = output_values.mutable_shape()->mutable_dims(); + for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { + output_indexes_dims->push_back(input_values_shape.dims(dim)); + output_values_dims->push_back(input_values_shape.dims(dim)); + } const int32_t k_value = input_k.GetBuffer().data[0]; output_indexes_dims->push_back(k_value); output_values_dims->push_back(k_value); - - } else { - output_indexes_dims->push_back(0); - output_values_dims->push_back(0); } } @@ -1578,6 +1574,61 @@ void ProcessAnyOperator(Model* model, AnyOperator* op) { } } +void ProcessOneHotOperator(Model* model, OneHotOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + // Yield until indices dims have been resolved. + const auto& indices_array = + model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]); + if (!indices_array.has_shape()) { + return; + } + + // Yield until depth is constant and dims have been resolved. + if (!IsConstantParameterArray(*model, + op->inputs[OneHotOperator::DEPTH_INPUT])) { + return; + } + const auto& depth_array = + model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]); + if (!depth_array.has_shape()) { + return; + } + + CHECK(depth_array.data_type == ArrayDataType::kInt32) + << "Depth array must be int32."; + CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1) + << "Depth array must be scalar."; + + const int depth = depth_array.GetBuffer().data[0]; + CHECK_GE(depth, 0) << "Depth must be non-negative."; + + const int indices_dims = indices_array.shape().dimensions_count(); + const int output_dims = indices_dims + 1; + const int axis = op->axis == -1 ? indices_dims : op->axis; + CHECK_GE(axis, 0) << "Resolved axis must be non-negative."; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(output_dims); + for (int i = 0; i < output_dims; ++i) { + int dim = 0; + if (i < axis) { + dim = indices_array.shape().dims(i); + } else if (i == axis) { + dim = depth; + } else { + dim = indices_array.shape().dims(i - 1); + } + (*mutable_dims)[i] = dim; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1618,6 +1669,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kSin: case OperatorType::kLogicalAnd: case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: @@ -1786,8 +1838,19 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessArgMinMaxOperator( model, static_cast(op)); break; - case OperatorType::kUnsupported: + case OperatorType::kUnsupported: { + const auto* unsupported_op = + static_cast(op); + // Attribute can be not specified, ignore it. + if (unsupported_op->output_shapes.size() < op->outputs.size()) { + return false; + } + for (int i = 0; i < op->outputs.size(); ++i) { + const string& output = op->outputs[i]; + model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i)); + } break; + } case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast(op)); break; @@ -1814,6 +1877,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kAny: ProcessAnyOperator(model, static_cast(op)); break; + case OperatorType::kOneHot: + ProcessOneHotOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index f6ce3b3ecb2cc06708287804bf34aa152d668f8c..8d22ae2eb1356b8c9c9430c517acddfc971b9f57 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -50,7 +50,7 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kPadV2 || type == OperatorType::kReshape || type == OperatorType::kTanh || type == OperatorType::kMul || - type == OperatorType::kBatchToSpaceND || + type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum || type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || @@ -61,9 +61,20 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kGreaterEqual || type == OperatorType::kLess || type == OperatorType::kLessEqual || type == OperatorType::kSelect || type == OperatorType::kArgMax || type == OperatorType::kRelu || - type == OperatorType::kRelu1 || type == OperatorType::kRelu6; + type == OperatorType::kRelu1 || type == OperatorType::kRelu6 || + type == OperatorType::kShape || type == OperatorType::kExpandDims; } +// The quantized op allows output arrays of type float using +// the attribute support_output_type_float_in_quantized_op +bool SupportOutputTypeFloatInQuantizedOp(const Operator& op) { + auto type = op.type; + if (type == OperatorType::kUnsupported) { + auto* unsupported = static_cast(&op); + return unsupported->support_output_type_float_in_quantized_op; + } + return false; +} const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { auto& array = model->GetArray(array_name); // Normally we should have a MinMax recorded on this Array, @@ -584,61 +595,67 @@ bool Quantize::Run(Model* model, std::size_t op_index) { } // Quantize outputs, add Dequantize ops as needed on the outputs side - for (std::size_t output_index = 0; output_index < op.outputs.size(); - output_index++) { - ArrayDataType quantized_data_type; - QuantizationParams quantization_params; - if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, - &quantized_data_type, - &quantization_params)) { - changed = true; - const auto& output = op.outputs[output_index]; - auto& output_array = model->GetArray(output); - - // Fix up the min/max information on the output array to match the chosen - // quantization parameters. - CHECK(output_array.minmax) - << "Output array named " << output << " lacks minmax"; - auto& output_minmax = output_array.GetMinMax(); - FixMinMaxPostQuantization(this, quantized_data_type, quantization_params, - &output_minmax); - - QuantizeArray(this, model, output, quantized_data_type, - quantization_params); - - const auto& dequantized_output = - AvailableArrayName(*model, output + "_dequantized"); - auto& dequantized_output_array = - model->GetOrCreateArray(dequantized_output); - dequantized_output_array.data_type = ArrayDataType::kFloat; - dequantized_output_array.final_data_type = output_array.data_type; - auto& dequantized_output_minmax = - dequantized_output_array.GetOrCreateMinMax(); - dequantized_output_minmax.min = output_minmax.min; - dequantized_output_minmax.max = output_minmax.max; - for (const auto& other_op : model->operators) { - for (auto& other_op_input : other_op->inputs) { - if (other_op_input == output) { - other_op_input = dequantized_output; + if (SupportOutputTypeFloatInQuantizedOp(op)) { + LOG(WARNING) + << HelpfulOperatorTypeName(op) << " is a quantized op" + << "but it has a model flag that sets the output arrays to float."; + } else { + for (std::size_t output_index = 0; output_index < op.outputs.size(); + output_index++) { + QuantizationParams quantization_params; + ArrayDataType quantized_data_type; + if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& output = op.outputs[output_index]; + auto& output_array = model->GetArray(output); + + // Fix up the min/max information on the output array to match the + // chosen quantization parameters. + CHECK(output_array.minmax) + << "Output array named " << output << " lacks minmax"; + auto& output_minmax = output_array.GetMinMax(); + FixMinMaxPostQuantization(this, quantized_data_type, + quantization_params, &output_minmax); + + QuantizeArray(this, model, output, quantized_data_type, + quantization_params); + + const auto& dequantized_output = + AvailableArrayName(*model, output + "_dequantized"); + auto& dequantized_output_array = + model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + dequantized_output_array.final_data_type = output_array.data_type; + auto& dequantized_output_minmax = + dequantized_output_array.GetOrCreateMinMax(); + dequantized_output_minmax.min = output_minmax.min; + dequantized_output_minmax.max = output_minmax.max; + for (const auto& other_op : model->operators) { + for (auto& other_op_input : other_op->inputs) { + if (other_op_input == output) { + other_op_input = dequantized_output; + } } } - } - auto* dequantize_op = new DequantizeOperator; - dequantize_op->inputs = {output}; - dequantize_op->outputs = {dequantized_output}; - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == output) { - // TODO(b/78013785): never rename output arrays. - AddMessageF( - "Renaming output array %d after inserting dequant op %s: %s -> " - "%s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantized_output); - model->flags.set_output_arrays(i, dequantized_output); + auto* dequantize_op = new DequantizeOperator; + dequantize_op->inputs = {output}; + dequantize_op->outputs = {dequantized_output}; + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == output) { + // TODO(b/78013785): never rename output arrays. + AddMessageF( + "Renaming output array %d after inserting dequant op %s: %s -> " + "%s", + i, LogName(*dequantize_op), model->flags.output_arrays(i), + dequantized_output); + model->flags.set_output_arrays(i, dequantized_output); + } } + const auto op_it = FindOp(*model, &op); + model->operators.emplace(op_it + 1, dequantize_op); } - const auto op_it = FindOp(*model, &op); - model->operators.emplace(op_it + 1, dequantize_op); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index 058f314b338aeeab94cb11fb8c1163427b559d3e..d395d7a6a0862d93fd4f52bb8b8d8d3ea7f8dc1e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -26,14 +26,17 @@ limitations under the License. namespace toco { template -void GetBoundsForQuantizedDataType(double* min, double* max) { +void GetBoundsForQuantizedDataType(float* min, float* max) { using limits = std::numeric_limits>; *min = limits::min(); *max = limits::max(); } void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type, - double* min, double* max) { + float* min, float* max) { + // It is important for matching accuracy between TF training and TFLite + // inference, that the min and max values are float to match TF's + // FakeQuantWithMinMaxVarsFunctor. switch (quantized_data_type) { case ArrayDataType::kUint8: return GetBoundsForQuantizedDataType(min, max); @@ -109,22 +112,22 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { QuantizationParams qparams; ChooseQuantizationParamsForArrayAndQuantizedDataType( output_array, quantized_data_type, &qparams); - double quantized_min, quantized_max; + float quantized_min, quantized_max; GetBoundsForQuantizedDataType(quantized_data_type, &quantized_min, &quantized_max); if (fakequant_op->narrow_range) { quantized_min++; } - for (int i = 0; i < size; i++) { - const double src_val = input_buffer.data[i]; - const double unclamped_quantized_val = - std::round(qparams.zero_point + src_val / qparams.scale); - const double quantized_val = std::min( - quantized_max, std::max(quantized_min, unclamped_quantized_val)); - const double dst_val = qparams.scale * (quantized_val - qparams.zero_point); - output_buffer.data[i] = dst_val; - } + // It is important for matching accuracy between TF training and TFLite + // inference, that the following variables are float to match TF's + // FakeQuantWithMinMaxVarsFunctor. + const float scale = qparams.scale; + const float nudged_min = (quantized_min - qparams.zero_point) * scale; + const float nudged_max = (quantized_max - qparams.zero_point) * scale; + tflite::FakeQuantizeArray(scale, nudged_min, nudged_max, + input_buffer.data.data(), output_buffer.data.data(), + size); if (IsDiscardableArray(*model, fakequant_op->inputs[0]) && CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 032c86394521a6f9dfafcbf5fdd3372568243bc6..b7fffbce2223a71ac1e16ec1ce18ba9f610cc2ac 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -215,7 +215,7 @@ tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -253,7 +253,7 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -290,7 +290,7 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -326,7 +326,7 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -363,7 +363,7 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -409,7 +409,7 @@ tensorflow::Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -1045,6 +1045,13 @@ tensorflow::Status ConvertSimpleOperator( tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { + // Names of special attributes in TF graph that are used by Toco. + static constexpr char kAttrOutputQuantized[] = "_output_quantized"; + static constexpr char kAttrOutputTypes[] = "_output_types"; + static constexpr char kAttrOutputShapes[] = "_output_shapes"; + static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] = + "_support_output_type_float_in_quantized_op"; + LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1055,11 +1062,17 @@ tensorflow::Status ConvertUnsupportedOperator( op->tensorflow_op = node.op(); node.SerializeToString(&op->tensorflow_node_def); model->operators.emplace_back(op); - if (HasAttr(node, "_output_quantized")) { - op->quantized = GetBoolAttr(node, "_output_quantized"); + // Parse if the op supports quantization + if (HasAttr(node, kAttrOutputQuantized)) { + op->quantized = GetBoolAttr(node, kAttrOutputQuantized); + } + // Parse if the quantized op allows output arrays of type float + if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) { + op->support_output_type_float_in_quantized_op = + GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp); } - if (HasAttr(node, "_output_types")) { - const auto& output_types = GetListAttr(node, "_output_types"); + if (HasAttr(node, kAttrOutputTypes)) { + const auto& output_types = GetListAttr(node, kAttrOutputTypes); for (int i = 0; i < output_types.type_size(); ++i) { op->output_data_types.push_back(ConvertDataType(output_types.type(i))); } @@ -1067,6 +1080,19 @@ tensorflow::Status ConvertUnsupportedOperator( const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } + if (HasAttr(node, kAttrOutputShapes)) { + const auto& output_shapes = GetListAttr(node, kAttrOutputShapes); + Shape output_shape; + for (int i = 0; i < output_shapes.shape_size(); ++i) { + const auto status = + ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr, + &output_shape); + if (!status.ok()) { + return status; + } + op->output_shapes.push_back(output_shape); + } + } return tensorflow::Status::OK(); } @@ -1197,11 +1223,10 @@ tensorflow::Status ConvertGatherOperator( return tensorflow::Status::OK(); } -template +template tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), op_name); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; @@ -1219,6 +1244,20 @@ tensorflow::Status ConvertArgMinMaxOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertArgMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMax"); + return ConvertArgMinMaxOperator(node, tf_import_flags, model); +} + +tensorflow::Status ConvertArgMinOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMin"); + return ConvertArgMinMaxOperator(node, tf_import_flags, model); +} + tensorflow::Status ConvertResizeBilinearOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1815,6 +1854,55 @@ tensorflow::Status ConvertSparseToDenseOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertOneHotOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "OneHot"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); + + const auto dtype = GetDataTypeAttr(node, "T"); + // TODO(b/111744875): Support DT_UINT8 and quantization. + CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT || + dtype == DT_BOOL); + + auto op = absl::make_unique(); + op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertCTCBeamSearchDecoderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + + auto* op = new CTCBeamSearchDecoderOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + + op->beam_width = + HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1; + op->top_paths = + HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1; + op->merge_repeated = HasAttr(node, "merge_repeated") + ? GetBoolAttr(node, "merge_repeated") + : true; + + // There are top_paths + 1 outputs. + op->outputs.push_back(node.name()); // Implicit :0. + for (int i = 0; i < op->top_paths; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i + 1)); + } + model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + } // namespace namespace internal { @@ -1824,17 +1912,14 @@ using ConverterType = tensorflow::Status (*)( Model* model); using ConverterMapType = std::unordered_map; -constexpr char kArgMax[] = "ArgMax"; -constexpr char kArgMin[] = "ArgMin"; - ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map({ {"Add", ConvertSimpleOperator}, {"AddN", ConvertSimpleOperator}, {"All", ConvertSimpleOperator}, {"Any", ConvertAnyOperator}, - {"ArgMax", ConvertArgMinMaxOperator}, - {"ArgMin", ConvertArgMinMaxOperator}, + {"ArgMax", ConvertArgMaxOperator}, + {"ArgMin", ConvertArgMinOperator}, {"Assert", ConvertSimpleOperator}, {"AvgPool", ConvertAvgPoolOperator}, {"BatchMatMul", ConvertBatchMatMulOperator}, @@ -1849,6 +1934,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Const", ConvertConstOperator}, {"Conv2D", ConvertConvOperator}, {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator}, {"DepthToSpace", ConvertDepthToSpaceOperator}, {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, {"Div", ConvertSimpleOperator}, @@ -1875,9 +1961,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Less", ConvertSimpleOperator}, {"LessEqual", ConvertSimpleOperator}, {"Log", ConvertSimpleOperator}, - {"LogSoftmax", ConvertSimpleOperator}, {"LogicalAnd", ConvertSimpleOperator}, + {"LogicalOr", ConvertSimpleOperator}, {"LogicalNot", ConvertSimpleOperator}, + {"LogSoftmax", ConvertSimpleOperator}, {"MatMul", ConvertMatMulOperator}, {"Max", ConvertReduceOperator}, {"MaxPool", ConvertMaxPoolOperator}, @@ -1891,6 +1978,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, {"NoOp", ConvertNoOpOperator}, {"NotEqual", ConvertSimpleOperator}, + {"OneHot", ConvertOneHotOperator}, {"Pack", ConvertPackOperator}, {"Pad", ConvertSimpleOperator}, {"PadV2", ConvertSimpleOperator}, diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d629787939d6930ae2d7969dc7efece545bcdcb4..412e14c4ada3280dafcd2fcfa59e2908dd785f9f 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -64,6 +64,7 @@ enum class OperatorType : uint8 { kMaxPool, kFakeQuant, kMul, + kOneHot, kRandomUniform, kRange, kRank, @@ -146,6 +147,8 @@ enum class OperatorType : uint8 { kAny, kLogicalAnd, kLogicalNot, + kLogicalOr, + kCTCBeamSearchDecoder, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -292,6 +295,46 @@ struct Buffer : GenericBuffer { std::vector> data; }; +class Shape { + public: + // For Shape, we stick to half-way encapsulation for now: + // we hide the raw dims_ member, but expose it raw by accessors + // because from some brainstorming, it's not at all easy to + // anticipate which flavor of more hermetic encapsulation would + // actually buy us future-proof-ness without being needlessly + // cumbersome. + Shape() {} + Shape(std::initializer_list dim_list) : dims_(dim_list) {} + + void ReplaceDims(std::initializer_list dim_list) { + dims_ = std::vector(dim_list); + } + + const std::vector& dims() const { return dims_; } + std::vector* mutable_dims() { return &dims_; } + const int dimensions_count() const { return dims_.size(); } + + // We still have that one convenience accessor to avoid + // the awkward double bracket issue: shape.dims()[i]. + int dims(int i) const { + // Always check for out-of-bounds accesses, even in optimized builds where + // standard assertions are disabled. Out-of-bounds access here is a common + // occurrence. + CHECK_GE(i, 0); + CHECK_GT(dims_.size(), i); + return dims_[i]; + } + + bool operator==(const Shape& comp) const { + return (this->dims_ == comp.dims()); + } + + bool operator!=(const Shape& comp) const { return !((*this) == comp); } + + private: + std::vector dims_; +}; + // Base class for all operator classes. struct Operator { // Non-default-constructible: only OperatorType-specific subclass @@ -396,6 +439,28 @@ struct ConvOperator : Operator { int dilation_height_factor = 1; }; +// CTCBeamSearchDecoder operator: +// +// Inputs: +// inputs[0]: required: the logits. +// inputs[1]: required: sequence length. +// inputs[2]: optional: beam width. +// inputs[3]: optional: top paths. +// inputs[4]: optional: merge repeated. +// +// Outputs: +// outputs[0]: deocoded. +// outputs[1]: log probability. +// +// TensorFlow equivalent: CTCBeamSearchDecoder +struct CTCBeamSearchDecoderOperator : Operator { + CTCBeamSearchDecoderOperator() + : Operator(OperatorType::kCTCBeamSearchDecoder) {} + int beam_width; + int top_paths; + bool merge_repeated = true; +}; + // Depthwise-separable convolution operator. // // Inputs: @@ -1467,8 +1532,13 @@ struct TensorFlowUnsupportedOperator : Operator { string tensorflow_node_def; // A boolean indicating if the unsupported op should be treated as quantized. bool quantized = false; + // A boolean indicating if the unsupported op output should allow float values + // in quantized mode. + bool support_output_type_float_in_quantized_op = false; // Output data types std::vector output_data_types; + // Output shapes. + std::vector output_shapes; }; // Softmax activation function. @@ -1726,6 +1796,38 @@ struct LogicalNotOperator : Operator { LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {} }; +// OneHot operator: +// +// Inputs: +// Inputs[0]: required: indices. +// Inputs[1]: required: depth. +// Inputs[2]: required: on_value. +// Inputs[3]: required: off_value. +// +// TensorFlow equivalent: OneHot. +struct OneHotOperator : Operator { + enum Inputs { + INDICES_INPUT = 0, + DEPTH_INPUT = 1, + ON_VALUE_INPUT = 2, + OFF_VALUE_INPUT = 3, + }; + + OneHotOperator() : Operator(OperatorType::kOneHot) {} + int axis = -1; +}; + +// LogicalOr operator: +// +// Inputs: +// Inputs[0]: required: A Bool tensor. +// Inputs[1]: required: A Bool tensor. +// +// TensorFlow equivalent: LogicalOr. +struct LogicalOrOperator : Operator { + LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are @@ -1739,46 +1841,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) { return a.start < b.start; } -class Shape { - public: - // For Shape, we stick to half-way encapsulation for now: - // we hide the raw dims_ member, but expose it raw by accessors - // because from some brainstorming, it's not at all easy to - // anticipate which flavor of more hermetic encapsulation would - // actually buy us future-proof-ness without being needlessly - // cumbersome. - Shape() {} - Shape(std::initializer_list dim_list) : dims_(dim_list) {} - - void ReplaceDims(std::initializer_list dim_list) { - dims_ = std::vector(dim_list); - } - - const std::vector& dims() const { return dims_; } - std::vector* mutable_dims() { return &dims_; } - const int dimensions_count() const { return dims_.size(); } - - // We still have that one convenience accessor to avoid - // the awkward double bracket issue: shape.dims()[i]. - int dims(int i) const { - // Always check for out-of-bounds accesses, even in optimized builds where - // standard assertions are disabled. Out-of-bounds access here is a common - // occurrence. - CHECK_GE(i, 0); - CHECK_GT(dims_.size(), i); - return dims_[i]; - } - - bool operator==(const Shape& comp) const { - return (this->dims_ == comp.dims()); - } - - bool operator!=(const Shape& comp) const { return !((*this) == comp); } - - private: - std::vector dims_; -}; - // Array represents an array (either a constant parameter array or an // activations array) in a Model. struct Array { @@ -2009,7 +2071,7 @@ class Model { std::size_t transient_data_size = 0; // For code-generation only: required alignment of the transient_data buffer std::size_t transient_data_alignment = 0; - // Arithmatic operations performed in the model. + // Arithmetic operations performed in the model. int64 ops_count = 0; private: diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index d1fdbcb8e9131e1d65fa32ca0395bbc17b2014e7..a95937ba0f4f66fedfab6c1528c8dc4e417297d0 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -262,7 +262,7 @@ TEST_F(VersionedOpExportTest, Export) { EXPECT_EQ(1, (*operators)[1]->opcode_index()); } -// TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. +// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators. } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 4b2ef756cce52dd803a93708969f217a9d420548..9ff89e9a653173fa1edc691b17b60f709af0a435 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1053,6 +1053,44 @@ class Shape int GetVersion(const Operator& op) const override { return 1; } }; +class OneHot : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateOneHotOptions(*builder, op.axis); + } + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class CTCBeamSearchDecoder + : public CustomOperator { + public: + using CustomOperator::CustomOperator; + + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("beam_width", op.beam_width); + fbb->Int("top_paths", op.top_paths); + fbb->Bool("merge_repeated", op.merge_repeated); + } + + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->beam_width = m["beam_width"].AsInt32(); + op->top_paths = m["top_paths"].AsInt32(); + op->merge_repeated = m["merge_repeated"].AsBool(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -1162,6 +1200,12 @@ class TensorFlowUnsupported : public BaseOperator { break; case flexbuffers::TYPE_BOOL: (*attr)[key].set_b(value.AsBool()); + if (string(key) == "_output_quantized") { + op->quantized = value.AsBool(); + } + if (string(key) == "_support_output_type_float_in_quantized_op") { + op->support_output_type_float_in_quantized_op = value.AsBool(); + } break; case flexbuffers::TYPE_VECTOR_INT: { auto* list = (*attr)[key].mutable_list(); @@ -1278,10 +1322,14 @@ std::vector> BuildOperatorList() { OperatorType::kFakeQuant)); ops.emplace_back( new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.emplace_back( + new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); // Custom Operators. ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); + ops.emplace_back(new CTCBeamSearchDecoder( + "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported)); @@ -1331,6 +1379,12 @@ std::vector> BuildOperatorList() { ops.emplace_back( new SimpleOperator("SLICE", OperatorType::kSlice)); ops.emplace_back(new SimpleOperator("POW", OperatorType::kPow)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_OR", OperatorType::kLogicalOr)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_AND", OperatorType::kLogicalAnd)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_NOT", OperatorType::kLogicalNot)); // Element-wise operator ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 44de6fbf6413257025263bfdea21bd260e37897c..fc854461b4e816e12e12590479501b6542258fef 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -127,6 +127,12 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("SQRT", OperatorType::kSqrt); CheckSimpleOperator("RSQRT", OperatorType::kRsqrt); CheckSimpleOperator("POW", OperatorType::kPow); + CheckSimpleOperator("LOGICAL_OR", + OperatorType::kLogicalOr); + CheckSimpleOperator("LOGICAL_AND", + OperatorType::kLogicalAnd); + CheckSimpleOperator("LOGICAL_NOT", + OperatorType::kLogicalNot); } TEST_F(OperatorTest, BuiltinAdd) { @@ -462,6 +468,28 @@ TEST_F(OperatorTest, BuiltinPack) { EXPECT_EQ(op.axis, output_toco_op->axis); } +TEST_F(OperatorTest, BuiltinOneHot) { + OneHotOperator op; + op.axis = 2; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("ONE_HOT", OperatorType::kOneHot), op); + EXPECT_EQ(op.axis, output_toco_op->axis); +} + +TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) { + CTCBeamSearchDecoderOperator op; + op.beam_width = 3; + op.top_paths = 2; + op.merge_repeated = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER", + OperatorType::kCTCBeamSearchDecoder), + op); + EXPECT_EQ(op.beam_width, output_toco_op->beam_width); + EXPECT_EQ(op.top_paths, output_toco_op->top_paths); + EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index de76fd4032d24eff8a6c2fd0c16a911b9c00186b..14168fa33f77a75706a52f00ddfa6b1120d90626 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -38,7 +38,8 @@ void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); } } // namespace port } // namespace toco -#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__) +#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && \ + !defined(__ANDROID__) && !defined(_WIN32) // Wrap Google file operations. @@ -115,9 +116,12 @@ string JoinPath(const string& a, const string& b) { } // namespace port } // namespace toco -#else // (__APPLE__ || __ANDROID__) +#else // !PLATFORM_GOOGLE || __APPLE__ || __ANDROID__ || _WIN32 #include +#if defined(_WIN32) +#include // for _close, _open, _read +#endif #include #include #include @@ -130,6 +134,19 @@ string JoinPath(const string& a, const string& b) { namespace toco { namespace port { +#if defined(_WIN32) +#define close _close +#define open _open +#define read _read +#define O_RDONLY _O_RDONLY +#define O_CREAT _O_CREAT +#define O_WRONLY _O_WRONLY +// Windows does not support the same set of file permissions as other platforms. +constexpr int kFileCreateMode = _S_IREAD | _S_IWRITE; +#else +constexpr int kFileCreateMode = 0664; +#endif // _WIN32 + static bool port_initialized = false; void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) { @@ -209,7 +226,7 @@ tensorflow::Status GetContents(const string& path, string* output, tensorflow::Status SetContents(const string& filename, const string& contents, const file::Options& options) { - int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); + int fd = open(filename.c_str(), O_WRONLY | O_CREAT, kFileCreateMode); if (fd == -1) { return tensorflow::errors::Internal("can't open() for write"); } @@ -243,4 +260,4 @@ string JoinPath(const string& base, const string& filename) { } // namespace port } // namespace toco -#endif // (__APPLE || __ANDROID__) +#endif // !PLATFORM_GOOGLE || __APPLE || __ANDROID__ || _WIN32 diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index aa7f6996eb8a166184cac04ff818c66fcdc34703..fcd3cbab07c06737f43d822e5b16f7c188f56b1a 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -309,8 +309,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) { // HardcodeMinMax to move changes through the graph as we make changes. auto propagate_default_min_max = absl::make_unique(); - if (toco_flags.has_default_ranges_min() && - toco_flags.has_default_ranges_max()) { + bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() && + toco_flags.has_default_ranges_max()); + if (has_default_ranges_flag) { propagate_default_min_max->DefineTypeRange( ArrayDataType::kUint8, toco_flags.default_ranges_min(), toco_flags.default_ranges_max()); @@ -335,6 +336,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) { new EnsureUint8WeightsSafeForFastInt8Kernels; ensure_safe_for_int8_kernels->set_allow_nudging_weights( toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel()); + ensure_safe_for_int8_kernels->set_has_default_ranges_flag( + has_default_ranges_flag); RunGraphTransformations(model, "quantization graph transformations", { new RemoveTrivialQuantizedActivationFunc, diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 98e416b76ed5d37cceecf8de56b4f173fd6a8d72..2ad27198119b4a8150a7381c047a4edb51aebfe6 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -356,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) + HANDLE_OPERATORTYPENAME_CASE(OneHot) HANDLE_OPERATORTYPENAME_CASE(Pack) HANDLE_OPERATORTYPENAME_CASE(Pad) HANDLE_OPERATORTYPENAME_CASE(PadV2) @@ -402,6 +403,8 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Any) HANDLE_OPERATORTYPENAME_CASE(LogicalAnd) HANDLE_OPERATORTYPENAME_CASE(LogicalNot) + HANDLE_OPERATORTYPENAME_CASE(LogicalOr) + HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -599,14 +602,33 @@ void UnextendShape(Shape* shape, int new_shape_size) { shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); } -bool IsValid(const Shape& shape) { +// In general, zero-sized dimensions are disallowed, but there are exceptions, +// e.g., if the tensor data itself represents a scalar (rank 0) shape, its +// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more +// strict, and is appropriate for ops and comparisons where an empty shape +// doesn't make sense. +template +void CheckValidShapeDimensions(const Dims& dims) { + if (dims.size() == 1 && dims[0] == 0) { + return; + } + for (const auto& dim : dims) { + CHECK_GE(dim, 1); + } +} + +void CheckValidShape(const Shape& shape) { + CheckValidShapeDimensions(shape.dims()); +} + +bool IsNonEmpty(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { if (shape.dims(i) < 1) return false; } return true; } -void CheckShapeDimensions(const Shape& shape) { +void CheckNonEmptyShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i << ". shape = " << ShapeToString(shape); @@ -614,8 +636,8 @@ void CheckShapeDimensions(const Shape& shape) { } bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { - CheckShapeDimensions(shape0); - CheckShapeDimensions(shape1); + CheckNonEmptyShapeDimensions(shape0); + CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; @@ -642,8 +664,8 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { } bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { - CheckShapeDimensions(shape0); - CheckShapeDimensions(shape1); + CheckNonEmptyShapeDimensions(shape0); + CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; @@ -680,9 +702,9 @@ bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { } int RequiredBufferSizeForShape(const Shape& shape) { + CheckValidShape(shape); int max_offset = 1; for (const auto& dim : shape.dims()) { - CHECK_GE(dim, 1); max_offset *= dim; } return max_offset; @@ -943,13 +965,7 @@ void CheckEachArray(const Model& model) { // shape. CHECK(array->has_shape()); // Constant buffer should has a valid shape. - bool is_scalar = - array->shape().dimensions_count() == 1 && array->shape().dims(0) == 0; - if (!is_scalar) { - for (int d : array->shape().dims()) { - CHECK_GE(d, 1); - } - } + CheckValidShape(array->shape()); // The shape flat-size should agree with the buffer length. CHECK_EQ(array->buffer->Length(), RequiredBufferSizeForShape(array->shape())); @@ -1541,8 +1557,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { if (!input_array.has_shape()) { if (input_array_proto.has_shape()) { auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); + CheckValidShapeDimensions(input_array_proto.shape().dims()); for (auto dim : input_array_proto.shape().dims()) { - CHECK_GE(dim, 1); input_array_dims.push_back(dim); } } @@ -1617,11 +1633,12 @@ void CheckIsReadyForQuantization(const Model& model) { << "Array " << input << ", which is an input to the " << HelpfulOperatorTypeName(*op) << " operator producing the output " << "array " << op->outputs[0] << ", is lacking min/max data, " - << "which is necessary for quantization. Either target a " - << "non-quantized output format, or change the input graph to " - << "contain min/max information, or pass --default_ranges_min= and " - << "--default_ranges_max= if you do not care about the accuracy of " - << "results."; + << "which is necessary for quantization. If accuracy matters, either " + << "target a non-quantized output format, or run quantized training " + << "with your model from a floating point checkpoint to change the " + << "input graph to contain min/max information. If you don't care " + << "about accuracy, you can pass --default_ranges_min= and " + << "--default_ranges_max= for easy experimentation."; } } } diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 5dbfa54fa0369676dce638aec171b409a468da9f..b99e6111fe92be178b5ff8b83477f1ce10c20926 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -115,10 +115,9 @@ void ExtendShape(Shape* shape, int new_shape_size); // TODO(b/36075966): Clean up when dims superseded by array shape. void UnextendShape(Shape* shape, int new_shape_size); -// Checks that all dimensions of 'shape' are at least 1. -bool IsValid(const Shape& shape); -// Same as above, but reports error using CHECK. -void CheckShapeDimensions(const Shape& shape); +// Checks that all dimensions of 'shape' are at least 1. Note that scalars, +// lacking dimensions, satisfy this condition and are considered non-empty. +bool IsNonEmpty(const Shape& shape); // Given two shapes with potentially different dimensionality and dimension // arrays d0 and d1. Without loss of generality, assume that shape0 may have diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py index e07f899e4d8c249cb03d4251a722df0614007fed..597dede63b0c089da21f4b0ede065189d8bbe1d8 100644 --- a/tensorflow/contrib/lite/tools/visualize.py +++ b/tensorflow/contrib/lite/tools/visualize.py @@ -334,7 +334,7 @@ def CreateHtmlFile(tflite_input, html_output): for key, mapping in toplevel_stuff: if not mapping: mapping = lambda x: x - html += "\n" % (key, mapping(data[key])) + html += "\n" % (key, mapping(data.get(key))) html += "
- Mobilenet_1.0_224(float) + Mobilenet_1.0_224(float) Pixel 2 166.5 ms (2.6 ms)
- Mobilenet_1.0_224 (quant) + Mobilenet_1.0_224 (quant) Pixel 2 69.5 ms (0.9 ms)
- Mobilenet_1.0_224(float) + Mobilenet_1.0_224(float) iPhone 8 32.2 ms (0.8 ms)
- Mobilenet_1.0_224 (quant) + Mobilenet_1.0_224 (quant) iPhone 8 24.4 ms (0.8 ms)
%s%s
%s%s
\n" # Spec on what keys to display diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 48953e2e3843ff92744514d28bd725cc0d72f3a8..448ae6d22e65fcd9129e27e6321d3081abf7d1ac 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,7 +30,11 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow +# 1.10 branch does not work. `make distclean` fails and blocks the build +# process. For now we're hardcoding to the version which is used by +# TensorFlow 1.9. +PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz" RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD index 54bd39afacbec07f054f61b72eda0a3654858aa7..16ddc38f5a5ba88485e18b136b2b1081b0e2ff0f 100644 --- a/tensorflow/contrib/model_pruning/BUILD +++ b/tensorflow/contrib/model_pruning/BUILD @@ -95,6 +95,22 @@ py_library( ], ) +py_library( + name = "strip_pruning_vars_lib", + srcs = ["python/strip_pruning_vars_lib.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pruning", + "//tensorflow/python:client", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_test( name = "pruning_utils_test", size = "small", @@ -129,6 +145,31 @@ py_test( ], ) +py_test( + name = "strip_pruning_vars_test", + size = "small", + srcs = ["python/strip_pruning_vars_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers", + ":pruning", + ":rnn_cells", + ":strip_pruning_vars_lib", + "//tensorflow/python:client_testlib", + ], +) + +py_binary( + name = "strip_pruning_vars", + srcs = ["python/strip_pruning_vars.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":strip_pruning_vars_lib", + "//tensorflow/python:platform", + ], +) + py_library( name = "init_py", srcs = ["__init__.py"], @@ -145,5 +186,6 @@ py_library( ":learning", ":pruning", ":rnn_cells", + ":strip_pruning_vars_lib", ], ) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 9143d082bf08fefa7aa522455eb3af911e636ae0..a5267fd90482287a65a4c38ae257a0af349523e8 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -4,7 +4,15 @@ This document describes the API that facilitates magnitude-based pruning of neural network's weight tensors. The API helps inject necessary tensorflow op into the training graph so the model can be pruned while it is being trained. -### Model creation +## Table of contents +1. [Model creation](#model-creation) +2. [Hyperparameters for pruning](#hyperparameters) + - [Block sparsity](#block-sparsity) +3. [Adding pruning ops to the training graph](#adding-pruning-ops) +4. [Removing pruning ops from trained model](#remove) +5. [Example](#example) + +### Model creation The first step involves adding mask and threshold variables to the layers that need to undergo pruning. The variable mask is the same shape as the layer's @@ -33,7 +41,7 @@ auxiliary variables built-in (see * [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154) -### Adding pruning ops to the training graph +### Pruning-related hyperparameters The pruning library allows for specification of the following hyper parameters: @@ -42,7 +50,7 @@ The pruning library allows for specification of the following hyper parameters: | name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope | | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | -| do_not_prune | list of strings | [""] | list of layers names that are not pruned | +| weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | nbins | integer | 256 | Number of bins to use for histogram computation | @@ -64,7 +72,13 @@ is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$ is the sparsity_function_begin_step. In this equation, the sparsity_function_exponent is set to 3. -### Adding pruning ops to the training graph + +#### Block Sparsity + +For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). +The convolution layer tensors are always pruned used block dimensions of [1,1]. + +### Adding pruning ops to the training graph The final step involves adding ops to the training graph that monitor the distribution of the layer's weight magnitudes and determine the layer threshold, @@ -105,7 +119,19 @@ with tf.graph.as_default(): ``` Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work! -## Example: Pruning and training deep CNNs on the cifar10 dataset +### Removing pruning ops from the trained graph +Once the model is trained, it is necessary to remove the auxiliary variables (mask, threshold) and pruning ops added to the graph in the steps above. This can be accomplished using the `strip_pruning_vars` utility. + +This utility generates a binary GraphDef in which the variables have been converted to constants. In particular, the threshold variables are removed from the graph and the mask variable is fused with the corresponding weight tensor to produce a `masked_weight` tensor. This tensor is sparse, has the same size as the weight tensor, and the sparsity is as set by the `target_sparsity` or the `weight_sparsity_map` hyperparameters above. + +```shell +$ bazel build -c opt contrib/model_pruning:strip_pruning_vars +$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb +``` + +For now, it is assumed that the underlying hardware platform will provide mechanisms for compressing the sparse tensors and/or accelerating the sparse tensor computations. + +## Example: Pruning and training deep CNNs on the cifar10 dataset Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural network architecture, setting up inputs etc. The additional changes needed to @@ -121,7 +147,7 @@ incorporate pruning are captured in the following: To train the pruned version of cifar10: -```bash +```shell $ examples_dir=contrib/model_pruning/examples $ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval} $ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000 @@ -133,10 +159,14 @@ Eval: $ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once ``` -### Block Sparsity +Removing pruning nodes from the trained graph: -For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). -The convolution layer tensors are always pruned used block dimensions of [1,1]. +```shell +$ bazel build -c opt contrib/model_pruning:strip_pruning_vars +$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_path=/tmp/cifar10_train --output_node_names=softmax_linear/softmax_linear_2 --filename=cifar_pruned.pb +``` + +The generated GraphDef (cifar_pruned.pb) may be visualized using the [`import_pb_to_tensorboard`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/tools/import_pb_to_tensorboard.py) utility ## References diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py index d32bedbcd6b63bc8e473a9e9d1c8e0753877e6f8..6eca54aaee186f5873a84ef2cb3ff3c7cfb42cd4 100644 --- a/tensorflow/contrib/model_pruning/__init__.py +++ b/tensorflow/contrib/model_pruning/__init__.py @@ -33,6 +33,9 @@ from tensorflow.contrib.model_pruning.python.pruning import get_thresholds from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity from tensorflow.contrib.model_pruning.python.pruning import get_weights from tensorflow.contrib.model_pruning.python.pruning import Pruning +from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import graph_def_from_checkpoint +from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import strip_pruning_vars_fn + # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -41,7 +44,8 @@ _allowed_symbols = [ 'masked_convolution', 'masked_conv2d', 'masked_fully_connected', 'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask', 'get_masked_weights', 'get_masks', 'get_pruning_hparams', 'get_thresholds', - 'get_weights', 'get_weight_sparsity', 'Pruning' + 'get_weights', 'get_weight_sparsity', 'Pruning', 'strip_pruning_vars_fn', + 'graph_def_from_checkpoint' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py index 466daf204a1ae86a7f37107342046305ea7249fc..d453e350f05c8e66df13c3861959980d69a564e8 100644 --- a/tensorflow/contrib/model_pruning/python/layers/layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -139,7 +139,7 @@ def masked_convolution(inputs, with "NC". num_outputs: Integer, the number of output filters. kernel_size: A sequence of N positive integers specifying the spatial - dimensions of of the filters. Can be a single integer to specify the same + dimensions of the filters. Can be a single integer to specify the same value for all spatial dimensions. stride: A sequence of N positive integers specifying the stride at which to compute output. Can be a single integer to specify the same value for all diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index da9d398cbc06299a33ab400cc9b4d780531211db..cd58526ed3620d4bd880cf36d806afac70c4bff7 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -152,8 +152,11 @@ def get_pruning_hparams(): end_pruning_step: integer the global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops - do_not_prune: list of strings - list of layers that are not pruned + weight_sparsity_map: list of strings + comma separed list of weight variable name:target sparsity pairs. + For layers/weights not in this list, sparsity as specified by the + target_sparsity hyperparameter is used. + Eg. [conv1:0.9,conv2/kernel:0.8] threshold_decay: float the decay factor to use for exponential decay of the thresholds pruning_frequency: integer @@ -200,7 +203,7 @@ def get_pruning_hparams(): name='model_pruning', begin_pruning_step=0, end_pruning_step=-1, - do_not_prune=[''], + weight_sparsity_map=[''], threshold_decay=0.9, pruning_frequency=10, nbins=256, @@ -234,6 +237,9 @@ class Pruning(object): # Pruning specification self._spec = spec if spec else get_pruning_hparams() + # Sanity check for pruning hparams + self._validate_spec() + # A tensorflow variable that tracks the sparsity function. # If not provided as input, the graph must already contain the global_step # variable before calling this constructor. @@ -256,6 +262,37 @@ class Pruning(object): # Block pooling function self._block_pooling_function = self._spec.block_pooling_function + # Mapping of weight names and target sparsity + self._weight_sparsity_map = self._get_weight_sparsity_map() + + def _validate_spec(self): + spec = self._spec + if spec.begin_pruning_step < 0: + raise ValueError('Illegal value for begin_pruning_step') + + if spec.begin_pruning_step >= spec.end_pruning_step: + if spec.end_pruning_step != -1: + raise ValueError( + 'Pruning must begin before it can end. begin_step=%d, end_step=%d.' + 'Set end_pruning_step to -1 if pruning is required till training' + 'stops' % (spec.begin_pruning_step, spec.end_pruning_step)) + + if spec.sparsity_function_begin_step < 0: + raise ValueError('Illegal value for sparsity_function_begin_step') + + if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step: + raise ValueError( + 'Sparsity function requires begin_step < end_step') + + if not 0.0 <= spec.threshold_decay < 1.0: + raise ValueError('threshold_decay must be in range [0,1)') + + if not 0.0 <= spec.initial_sparsity < 1.0: + raise ValueError('initial_sparsity must be in range [0,1)') + + if not 0.0 <= spec.target_sparsity < 1.0: + raise ValueError('target_sparsity must be in range [0,1)') + def _setup_global_step(self, global_step): graph_global_step = global_step if graph_global_step is None: @@ -270,11 +307,6 @@ class Pruning(object): target_sparsity = self._spec.target_sparsity exponent = self._spec.sparsity_function_exponent - if begin_step >= end_step: - raise ValueError( - 'Pruning must begin before it can end. begin_step=%d, end_step=%d' % - (begin_step, end_step)) - with ops.name_scope(self._spec.name): p = math_ops.minimum( 1.0, @@ -306,15 +338,36 @@ class Pruning(object): 'last_mask_update_step', dtype=dtypes.int32) return last_update_step - def _exists_in_do_not_prune_list(self, tensor_name): - do_not_prune_list = self._spec.do_not_prune - if not do_not_prune_list[0]: - return False - for layer_name in do_not_prune_list: - if tensor_name.find(layer_name) != -1: - return True - - return False + def _get_weight_sparsity_map(self): + """Return the map of weight_name:sparsity parsed from the hparams.""" + weight_sparsity_map = {} + val_list = self._spec.weight_sparsity_map + filtered_val_list = [l for l in val_list if l] + for val in filtered_val_list: + weight_name, sparsity = val.split(':') + if float(sparsity) >= 1.0: + raise ValueError('Weight sparsity can not exceed 1.0') + weight_sparsity_map[weight_name] = float(sparsity) + + return weight_sparsity_map + + def _get_sparsity(self, weight_name): + """Return target sparsity for the given layer/weight name.""" + target_sparsity = [ + sparsity for name, sparsity in self._weight_sparsity_map.items() + if weight_name.find(name) != -1 + ] + if not target_sparsity: + return self._sparsity + + if len(target_sparsity) > 1: + raise ValueError( + 'Multiple matches in weight_sparsity_map for weight %s' % weight_name) + # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize + # to handle other cases as well. + return math_ops.mul( + self._sparsity, + math_ops.div(target_sparsity[0], self._spec.target_sparsity)) def _update_mask(self, weights, threshold): """Updates the mask for a given weight tensor. @@ -342,6 +395,8 @@ class Pruning(object): if self._sparsity is None: raise ValueError('Sparsity variable undefined') + sparsity = self._get_sparsity(weights.op.name) + with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs(weights) max_value = math_ops.reduce_max(abs_weights) @@ -354,7 +409,7 @@ class Pruning(object): math_ops.div( math_ops.reduce_sum( math_ops.cast( - math_ops.less(norm_cdf, self._sparsity), dtypes.float32)), + math_ops.less(norm_cdf, sparsity), dtypes.float32)), float(self._spec.nbins)), max_value) smoothed_threshold = math_ops.add_n([ @@ -453,10 +508,6 @@ class Pruning(object): if is_partitioned: weight = weight.as_tensor() - if self._spec.do_not_prune: - if self._exists_in_do_not_prune_list(mask.name): - continue - new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold) self._assign_ops.append( pruning_utils.variable_assign(threshold, new_threshold)) @@ -507,22 +558,15 @@ class Pruning(object): no_update_op) def add_pruning_summaries(self): - """Adds summaries for this pruning spec. - - Args: none - - Returns: none - """ + """Adds summaries of weight sparsities and thresholds.""" with ops.name_scope(self._spec.name + '_summaries'): summary.scalar('sparsity', self._sparsity) summary.scalar('last_mask_update_step', self._last_update_step) masks = get_masks() thresholds = get_thresholds() for mask, threshold in zip(masks, thresholds): - if not self._exists_in_do_not_prune_list(mask.name): - summary.scalar(mask.op.name + '/sparsity', - nn_impl.zero_fraction(mask)) - summary.scalar(threshold.op.name + '/threshold', threshold) + summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask)) + summary.scalar(threshold.op.name + '/threshold', threshold) def print_hparams(self): logging.info(self._spec.to_json()) diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index f80b7c52c000f13b5ce98dd442ff21abfac37761..33c4ad58bd7f57422935fc839ddfc64d5e1f00f5 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -35,8 +35,8 @@ from tensorflow.python.training import training_util class PruningHParamsTest(test.TestCase): PARAM_LIST = [ "name=test", "threshold_decay=0.9", "pruning_frequency=10", - "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100", - "target_sparsity=0.9" + "sparsity_function_end_step=100", "target_sparsity=0.9", + "weight_sparsity_map=[conv1:0.8,conv2/kernel:0.8]" ] TEST_HPARAMS = ",".join(PARAM_LIST) @@ -55,9 +55,10 @@ class PruningHParamsTest(test.TestCase): self.assertEqual(p._spec.name, "test") self.assertAlmostEqual(p._spec.threshold_decay, 0.9) self.assertEqual(p._spec.pruning_frequency, 10) - self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"]) self.assertEqual(p._spec.sparsity_function_end_step, 100) self.assertAlmostEqual(p._spec.target_sparsity, 0.9) + self.assertEqual(p._weight_sparsity_map["conv1"], 0.8) + self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8) def testInitWithExternalSparsity(self): with self.test_session(): @@ -211,6 +212,37 @@ class PruningTest(test.TestCase): expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40] self.assertAllEqual(expected_non_zero_count, non_zero_count) + def testWeightSpecificSparsity(self): + param_list = [ + "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100", + "target_sparsity=0.5", "weight_sparsity_map=[layer2/weights:0.75]", + "threshold_decay=0.0" + ] + test_spec = ",".join(param_list) + pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) + + with variable_scope.variable_scope("layer1"): + w1 = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + _ = pruning.apply_mask(w1) + with variable_scope.variable_scope("layer2"): + w2 = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + _ = pruning.apply_mask(w2) + + p = pruning.Pruning(pruning_hparams) + mask_update_op = p.conditional_mask_update_op() + increment_global_step = state_ops.assign_add(self.global_step, 1) + + with self.test_session() as session: + variables.global_variables_initializer().run() + for _ in range(110): + session.run(mask_update_op) + session.run(increment_global_step) + + self.assertAllEqual( + session.run(pruning.get_weight_sparsity()), [0.5, 0.75]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..3385103807f6dbdab2d27882c670a3ccf6a26e9d --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Removes the auxiliary variables and ops added by the pruning library. + +Usage: + +bazel build tensorflow/contrib/model_pruning:strip_pruning_vars && \ +bazel-bin/tensorflow/contrib/model_pruning/strip_pruning_vars \ +--checkpoint_dir=/tmp/model_ckpts \ +--output_node_names=softmax \ +--output_dir=/tmp \ +--filename=pruning_stripped.pb +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib +from tensorflow.python.framework import graph_io +from tensorflow.python.platform import app +from tensorflow.python.platform import tf_logging as logging + +FLAGS = None + + +def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir, filename): + """Remove pruning-related auxiliary variables and ops from the graph. + + Accepts training checkpoints and produces a GraphDef in which the pruning vars + and ops have been removed. + + Args: + checkpoint_dir: Path to the checkpoints. + output_node_names: The name of the output nodes, comma separated. + output_dir: Directory where to write the graph. + filename: Output GraphDef file name. + + Returns: + None + + Raises: + ValueError: if output_nodes_names are not provided. + """ + if not output_node_names: + raise ValueError( + 'Need to specify atleast 1 output node through output_node_names flag') + output_node_names = output_node_names.replace(' ', '').split(',') + + initial_graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint( + checkpoint_dir, output_node_names) + + final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn( + initial_graph_def, output_node_names) + graph_io.write_graph(final_graph_def, output_dir, filename, as_text=False) + logging.info('\nFinal graph written to %s', os.path.join( + output_dir, filename)) + + +def main(unused_args): + return strip_pruning_vars(FLAGS.checkpoint_dir, FLAGS.output_node_names, + FLAGS.output_dir, FLAGS.filename) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--checkpoint_dir', type=str, default='', help='Path to the checkpoints.') + parser.add_argument( + '--output_node_names', + type=str, + default='', + help='The name of the output nodes, comma separated.') + parser.add_argument( + '--output_dir', + type=str, + default='/tmp', + help='Directory where to write the graph.') + parser.add_argument( + '--filename', + type=str, + default='pruning_stripped.pb', + help='Output \'GraphDef\' file name.') + + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4b10863f7c46235059f948fbbfcfcf83d3e15b --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py @@ -0,0 +1,142 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to remove pruning-related ops and variables from a GraphDef. +""" + +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib + + +def _node_name(tensor_name): + """Remove the trailing ':0' from the variable name.""" + if ':' not in tensor_name: + return tensor_name + + return tensor_name.split(':')[0] + + +def _tensor_name(node_name): + """Appends the :0 in the op name to get the canonical tensor name.""" + if ':' in node_name: + return node_name + + return node_name + ':0' + + +def _get_masked_weights(input_graph_def): + """Extracts masked_weights from the graph as a dict of {var_name:ndarray}.""" + input_graph = ops.Graph() + with input_graph.as_default(): + importer.import_graph_def(input_graph_def, name='') + + with session.Session(graph=input_graph) as sess: + masked_weights_dict = {} + for node in input_graph_def.node: + if 'masked_weight' in node.name: + masked_weight_val = sess.run( + sess.graph.get_tensor_by_name(_tensor_name(node.name))) + logging.info( + '%s has %d values, %1.2f%% zeros \n', node.name, + np.size(masked_weight_val), + 100 - float(100 * np.count_nonzero(masked_weight_val)) / + np.size(masked_weight_val)) + masked_weights_dict.update({node.name: masked_weight_val}) + return masked_weights_dict + + +def strip_pruning_vars_fn(input_graph_def, output_node_names): + """Removes mask variable from the graph. + + Replaces the masked_weight tensor with element-wise multiplication of mask + and the corresponding weight variable. + + Args: + input_graph_def: A GraphDef in which the variables have been converted to + constants. This is typically the output of + tf.graph_util.convert_variables_to_constant() + output_node_names: List of name strings for the result nodes of the graph + + Returns: + A GraphDef in which pruning-related variables have been removed + """ + masked_weights_dict = _get_masked_weights(input_graph_def) + pruned_graph_def = graph_pb2.GraphDef() + + # Replace masked_weight with a const op containing the + # result of tf.multiply(mask,weight) + for node in input_graph_def.node: + output_node = node_def_pb2.NodeDef() + if 'masked_weight' in node.name: + output_node.op = 'Const' + output_node.name = node.name + dtype = node.attr['T'] + data = masked_weights_dict[node.name] + output_node.attr['dtype'].CopyFrom(dtype) + output_node.attr['value'].CopyFrom( + attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(data))) + + else: + output_node.CopyFrom(node) + pruned_graph_def.node.extend([output_node]) + + # Remove stranded nodes: mask and weights + return graph_util.extract_sub_graph(pruned_graph_def, output_node_names) + + +def graph_def_from_checkpoint(checkpoint_dir, output_node_names): + """Converts checkpoint data to GraphDef. + + Reads the latest checkpoint data and produces a GraphDef in which the + variables have been converted to constants. + + Args: + checkpoint_dir: Path to the checkpoints. + output_node_names: List of name strings for the result nodes of the graph. + + Returns: + A GraphDef from the latest checkpoint + + Raises: + ValueError: if no checkpoint is found + """ + checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir) + if checkpoint_path is None: + raise ValueError('Could not find a checkpoint at: {0}.' + .format(checkpoint_dir)) + + saver_for_restore = saver_lib.import_meta_graph( + checkpoint_path + '.meta', clear_devices=True) + with session.Session() as sess: + saver_for_restore.restore(sess, checkpoint_path) + graph_def = ops.get_default_graph().as_graph_def() + output_graph_def = graph_util.convert_variables_to_constants( + sess, graph_def, output_node_names) + + return output_graph_def diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py new file mode 100644 index 0000000000000000000000000000000000000000..255daa036099c0d3ef2dbc5eb37fdb0c31c71383 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py @@ -0,0 +1,232 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for strip_pruning_vars.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.contrib.model_pruning.python import pruning +from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib +from tensorflow.contrib.model_pruning.python.layers import layers +from tensorflow.contrib.model_pruning.python.layers import rnn_cells +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell as tf_rnn_cells +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + + +def _get_number_pruning_vars(graph_def): + number_vars = 0 + for node in graph_def.node: + if re.match(r"^.*(mask$)|(threshold$)", node.name): + number_vars += 1 + return number_vars + + +def _get_node_names(tensor_names): + return [ + strip_pruning_vars_lib._node_name(tensor_name) + for tensor_name in tensor_names + ] + + +class StripPruningVarsTest(test.TestCase): + + def setUp(self): + param_list = [ + "pruning_frequency=1", "begin_pruning_step=1", "end_pruning_step=10", + "nbins=2048", "threshold_decay=0.0" + ] + self.initial_graph = ops.Graph() + self.initial_graph_def = None + self.final_graph = ops.Graph() + self.final_graph_def = None + self.pruning_spec = ",".join(param_list) + with self.initial_graph.as_default(): + self.sparsity = variables.Variable(0.5, name="sparsity") + self.global_step = training_util.get_or_create_global_step() + self.increment_global_step = state_ops.assign_add(self.global_step, 1) + self.mask_update_op = None + + def _build_convolutional_model(self, number_of_layers): + # Create a graph with several conv2d layers + kernel_size = 3 + base_depth = 4 + depth_step = 7 + height, width = 7, 9 + with variable_scope.variable_scope("conv_model"): + input_tensor = array_ops.ones((8, height, width, base_depth)) + top_layer = input_tensor + for ix in range(number_of_layers): + top_layer = layers.masked_conv2d( + top_layer, + base_depth + (ix + 1) * depth_step, + kernel_size, + scope="Conv_" + str(ix)) + + return top_layer + + def _build_fully_connected_model(self, number_of_layers): + base_depth = 4 + depth_step = 7 + + input_tensor = array_ops.ones((8, base_depth)) + + top_layer = input_tensor + + with variable_scope.variable_scope("fc_model"): + for ix in range(number_of_layers): + top_layer = layers.masked_fully_connected( + top_layer, base_depth + (ix + 1) * depth_step) + + return top_layer + + def _build_lstm_model(self, number_of_layers): + batch_size = 8 + dim = 10 + inputs = variables.Variable(random_ops.random_normal([batch_size, dim])) + + def lstm_cell(): + return rnn_cells.MaskedBasicLSTMCell( + dim, forget_bias=0.0, state_is_tuple=True, reuse=False) + + cell = tf_rnn_cells.MultiRNNCell( + [lstm_cell() for _ in range(number_of_layers)], state_is_tuple=True) + + outputs = rnn.static_rnn( + cell, [inputs], + initial_state=cell.zero_state(batch_size, dtypes.float32)) + + return outputs + + def _prune_model(self, session): + pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec) + p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity) + self.mask_update_op = p.conditional_mask_update_op() + + variables.global_variables_initializer().run() + for _ in range(20): + session.run(self.mask_update_op) + session.run(self.increment_global_step) + + def _get_outputs(self, session, input_graph, tensors_list, graph_prefix=None): + outputs = [] + + for output_tensor in tensors_list: + if graph_prefix: + output_tensor = graph_prefix + "/" + output_tensor + outputs.append( + session.run(session.graph.get_tensor_by_name(output_tensor))) + + return outputs + + def _get_initial_outputs(self, output_tensor_names_list): + with self.test_session(graph=self.initial_graph) as sess1: + self._prune_model(sess1) + reference_outputs = self._get_outputs(sess1, self.initial_graph, + output_tensor_names_list) + + self.initial_graph_def = graph_util.convert_variables_to_constants( + sess1, sess1.graph.as_graph_def(), + _get_node_names(output_tensor_names_list)) + return reference_outputs + + def _get_final_outputs(self, output_tensor_names_list): + self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn( + self.initial_graph_def, _get_node_names(output_tensor_names_list)) + _ = importer.import_graph_def(self.final_graph_def, name="final") + + with self.test_session(self.final_graph) as sess2: + final_outputs = self._get_outputs( + sess2, + self.final_graph, + output_tensor_names_list, + graph_prefix="final") + return final_outputs + + def _check_removal_of_pruning_vars(self, number_masked_layers): + self.assertEqual( + _get_number_pruning_vars(self.initial_graph_def), number_masked_layers) + self.assertEqual(_get_number_pruning_vars(self.final_graph_def), 0) + + def _check_output_equivalence(self, initial_outputs, final_outputs): + for initial_output, final_output in zip(initial_outputs, final_outputs): + self.assertAllEqual(initial_output, final_output) + + def testConvolutionalModel(self): + with self.initial_graph.as_default(): + number_masked_conv_layers = 5 + top_layer = self._build_convolutional_model(number_masked_conv_layers) + output_tensor_names = [top_layer.name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_conv_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + def testFullyConnectedModel(self): + with self.initial_graph.as_default(): + number_masked_fc_layers = 3 + top_layer = self._build_fully_connected_model(number_masked_fc_layers) + output_tensor_names = [top_layer.name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_fc_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + def testLSTMModel(self): + with self.initial_graph.as_default(): + number_masked_lstm_layers = 2 + outputs = self._build_lstm_model(number_masked_lstm_layers) + output_tensor_names = [outputs[0][0].name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_lstm_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index bbdf962d0480e52045d31f65b3d137ed3f11f2f1..778b710d78a2095b8a1315018641c67419c26b98 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -27,6 +27,7 @@ py_library( "python/training/nadam_optimizer.py", "python/training/powersign.py", "python/training/reg_adagrad_optimizer.py", + "python/training/shampoo.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", "python/training/weight_decay_optimizers.py", @@ -344,3 +345,23 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "shampoo_test", + size = "large", + srcs = ["python/training/shampoo_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 3e63e99030c46c254625ca8fdccce614cd60e8b0..9471fb018162ee377e9c614d6e4d745b4282165a 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -30,10 +30,10 @@ from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.shampoo import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -62,6 +62,7 @@ _allowed_symbols = [ 'ModelAverageOptimizer', 'ModelAverageCustomGetter', 'GGTOptimizer', + 'ShampooOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py new file mode 100644 index 0000000000000000000000000000000000000000..a98866b180330727a8331faf1f1d9b7e398108e4 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -0,0 +1,474 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""The Shampoo Optimizer. + +Variant of Adagrad using one preconditioner matrix per variable dimension. +For details, see https://arxiv.org/abs/1802.09568 +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import optimizer + + +def GetParam(var, timestep): + if callable(var): + return var(timestep) + else: + return var + + +class ShampooOptimizer(optimizer.Optimizer): + """The Shampoo Optimizer + + Variant of Adagrad using one preconditioner matrix per variable dimension. + For details, see https://arxiv.org/abs/1802.09568 + + gbar is time-weighted accumulated gradient: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + + mat_gbar is time-weighted accumulated gradient square: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation) + + Update rule: + w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t] + Again, mat_gbar_j[t]^(-alpha) gbar[t] is a tensor contraction along the + j'th dimension of gbar[t] with the first dimension of + mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter, + and n = rank of the variable. + Prod_j represents doing this contraction for all j in 0..n-1. + + Typically learning_rate is constant, but could be time dependent by passing + a lambda function that depends on step. + """ + + def __init__(self, + global_step=0, + max_matrix_size=768, + gbar_decay=0.0, + gbar_weight=1.0, + mat_gbar_decay=1.0, + mat_gbar_weight=1.0, + learning_rate=1.0, + svd_interval=1, + precond_update_interval=1, + epsilon=0.1, + alpha=0.5, + use_iterative_root=False, + use_locking=False, + name="Shampoo"): + """Default values of the various hyper-parameters. + + gbar_decay, gbar_weight etc. can be a float or a time varying parameter. + For time-varying parameters use e.g. "lambda T: T / (T + 1.0)" + where the expression in the lambda is a tensorflow expression + + Args: + global_step: tensorflow variable indicating the step. + max_matrix_size: We do not perform SVD for matrices larger than this. + gbar_decay: + gbar_weight: Used to update gbar: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + mat_gbar_decay: + mat_gbar_weight: Used to update mat_gbar: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + learning_rate: Similar to SGD + svd_interval: We should do SVD after this many steps. Default = 1, i.e. + every step. Usually 20 leads to no loss of accuracy, and + 50 or 100 is also OK. May also want more often early, + and less often later - set in caller as for example: + "svd_interval = lambda(T): tf.cond( + T < 2000, lambda: 20.0, lambda: 1000.0)" + precond_update_interval: We should update the preconditioners after + this many steps. Default = 1. Usually less than + svd_interval. + epsilon: epsilon * I_n is added to each mat_gbar_j for stability + alpha: total power of the preconditioners. + use_iterative_root: should the optimizer use SVD (faster) or the + iterative root method (for TPU) for finding the + roots of PSD matrices. + use_locking: + name: name of optimizer. + """ + + super(ShampooOptimizer, self).__init__(use_locking, name) + + self._global_step = math_ops.to_float(global_step) + self._max_matrix_size = max_matrix_size + self._gbar_decay = gbar_decay + self._gbar_weight = gbar_weight + self._mat_gbar_decay = mat_gbar_decay + self._mat_gbar_weight = mat_gbar_weight + self._learning_rate = learning_rate + self._svd_interval = svd_interval + self._precond_update_interval = precond_update_interval + self._epsilon = epsilon + self._alpha = alpha + self._use_iterative_root = use_iterative_root + self._name = name + + def _create_slots(self, var_list): + for v in var_list: + with ops.colocate_with(v): + _ = self._zeros_slot(v, "gbar", self._name) + shape = np.array(v.get_shape()) + for i, d in enumerate(shape): + d_tensor = ops.convert_to_tensor(d) + if d < self._max_matrix_size: + mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor)) + if self._svd_interval > 1: + _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor), + "H_" + str(i), self._name) + else: + mat_g_init = array_ops.zeros([d_tensor]) + + _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i), + self._name) + + def _resource_apply_dense(self, grad, var): + return self._apply_dense(grad, var) + + def _apply_dense(self, grad, var): + return self._apply_gradient(grad, var) + + def _resource_apply_sparse(self, grad_values, var, grad_indices): + return self._apply_sparse_shared(grad_values, grad_indices, var) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared(grad.values, grad.indices, var) + + def _apply_sparse_shared(self, grad_values, grad_indices, var): + if var.get_shape()[0] < self._max_matrix_size or self._gbar_decay != 0.0: + # The dimension is small enough, we can make the variable dense and + # do a dense update + dense_grad = array_ops.scatter_nd( + array_ops.expand_dims(grad_indices, axis=1), grad_values, + array_ops.shape(var, out_type=grad_indices.dtype)) + return self._apply_gradient(dense_grad, var) + return self._apply_gradient(grad_values, var, grad_indices) + + def _weighted_average(self, var, weight, weight_t, rest): + """Computes exponential weighted average: var = weight_t * var + rest. + + Important to ensure that var does not occur in rest, otherwise + we can get race conditions in a distributed setting. + + Args: + var: variable to be updated + weight: parameter to be checked. If it is a constant, we can optimize. + weight_t: current value of parameter, used for weighting + rest: the remaining tensor to be added + + Returns: + updated variable. + """ + if weight == 0.0: + return rest # no need to update var, we will never use it. + if weight == 1.0: # common case + return state_ops.assign_add(var, rest) + # The op below can cause race conditions in a distributed setting, + # since computing weight_t * var + rest can take some time, during + # which var may be set by another worker. To prevent this, it should + # be implemented as a C++ op. + return var.assign_add((weight_t - 1) * var + rest) + + def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay, + mat_gbar_weight, i): + """Updates the cumulative outer products of the gradients. + + Args: + mat_g: the matrix to be updated + grad: the gradient of the variable + axes: a list of k-1 integers 0 to k-1, except i + mat_gbar_decay: constant for weighted average: + mat_g = mat_g * decay + grad * weight + mat_gbar_weight: constant for weighted average + i: index of dimension to be updated. + + Returns: + updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight + + In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd + thus grad_outer is a matrix d_i x d_i, where d_i is the size of the + i'th dimension of g. + Alternate view: If mat_i(grad) is the flattening of grad to a + d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then + grad_outer = mat_i(grad) mat_i(grad).transpose + """ + grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes), + name="grad_outer_" + str(i)) + return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay, + mat_gbar_weight * grad_outer) + + def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name): + """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix. + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g + alpha: a real number + mat_h_slot_name: name of slot to store the power, if needed. + + Returns: + mat_h = mat_g^alpha + + Stores mat_h in the appropriate slot, if it exists. + Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig. + """ + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size)) + diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) + mat_h = math_ops.matmul( + mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), + array_ops.transpose(mat_u)) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name, + iter_count=100, epsilon=1e-6): + """Computes mat_g^alpha, where alpha = -1/p, p a positive integer. + + We use an iterative Schur-Newton method from equation 3.2 on page 9 of: + + A Schur-Newton Method for the Matrix p-th Root and its Inverse + by Chun-Hua Guo and Nicholas J. Higham + SIAM Journal on Matrix Analysis and Applications, + 2006, Vol. 28, No. 3 : pp. 788-804 + https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g. + alpha: exponent, must be -1/p for p a positive integer. + mat_h_slot_name: name of slot to store the power, if needed. + iter_count: Maximum number of iterations. + epsilon: accuracy indicator, useful for early termination. + + Returns: + mat_g^alpha + """ + + identity = linalg_ops.eye(math_ops.to_int32(mat_g_size)) + + def MatPower(mat_m, p): + """Computes mat_m^p, for p a positive integer. + + Power p is known at graph compile time, so no need for loop and cond. + Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + assert p == int(p) and p > 0 + power = None + while p > 0: + if p % 2 == 1: + power = math_ops.matmul(mat_m, power) if power is not None else mat_m + p //= 2 + mat_m = math_ops.matmul(mat_m, mat_m) + return power + + def IterCondition(i, mat_m, _): + return math_ops.logical_and( + i < iter_count, + math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon) + + def IterBody(i, mat_m, mat_x): + mat_m_i = (1 - alpha) * identity + alpha * mat_m + return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m), + math_ops.matmul(mat_x, mat_m_i)) + + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damped_mat_g = mat_g + self._epsilon * identity + z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g)) + # The best value for z is + # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) / + # (c_max^{1-alpha} - c_min^{1-alpha}) + # where c_max and c_min are the largest and smallest singular values of + # damped_mat_g. + # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha) + # Can replace above line by the one below, but it is less accurate, + # hence needs more iterations to converge. + # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g) + # If we want the method to always converge, use z = 1 / norm(damped_mat_g) + # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many + # extra iterations. + _, _, mat_h = control_flow_ops.while_loop( + IterCondition, IterBody, + [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)]) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None): + """Just a switch between the iterative power vs svd.""" + with ops.name_scope("matrix_iterative_power"): + if self._use_iterative_root: + return self._compute_power_iter(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + else: + return self._compute_power_svd(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + + def _apply_gradient(self, grad, var, indices=None): + """The main function to update a variable. + + Args: + grad: A Tensor containing gradient to apply. + var: A Tensor containing the variable to update. + indices: An array of integers, for sparse update. + + Returns: + Updated variable var = var - learning_rate * preconditioner * grad + + If the gradient is dense, var and grad have the same shape. + If the update is sparse, then the first dimension of the gradient and var + may differ, others are all the same. In this case the indices array + provides the set of indices of the variable which are to be updated with + each row of the gradient. + """ + global_step = self._global_step + 1 + + # Update accumulated weighted average of gradients + gbar = self.get_slot(var, "gbar") + gbar_decay_t = GetParam(self._gbar_decay, global_step) + gbar_weight_t = GetParam(self._gbar_weight, global_step) + if indices is not None: + # Note - the sparse update is not easily implemented, since the + # algorithm needs all indices of gbar to be updated + # if mat_gbar_decay != 1 or mat_gbar_decay != 0. + # One way to make mat_gbar_decay = 1 is by rescaling. + # If we want the update: + # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t + # define: + # r_{t+1} = a_{t+1} * r_t + # h_t = G_t / r_t + # Then: + # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t + # So we get the mat_gbar_decay = 1 as desired. + # We can implement this in a future version as needed. + # However we still need gbar_decay = 0, otherwise all indices + # of the variable will need to be updated. + if self._gbar_decay != 0.0: + tf_logging.warning("Not applying momentum for variable: %s" % var.name) + gbar_updated = grad + else: + gbar_updated = self._weighted_average(gbar, self._gbar_decay, + gbar_decay_t, + gbar_weight_t * grad) + + # Update the preconditioners and compute the preconditioned gradient + shape = var.get_shape() + mat_g_list = [] + for i in range(len(shape)): + mat_g_list.append(self.get_slot(var, "Gbar_" + str(i))) + mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step) + mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step) + + preconditioned_grad = gbar_updated + v_rank = len(mat_g_list) + neg_alpha = - GetParam(self._alpha, global_step) / v_rank + svd_interval = GetParam(self._svd_interval, global_step) + precond_update_interval = GetParam(self._precond_update_interval, + global_step) + for i, mat_g in enumerate(mat_g_list): + # axes is the list of indices to reduce - everything but the current i. + axes = list(range(i)) + list(range(i+1, v_rank)) + if shape[i] < self._max_matrix_size: + # If the tensor size is sufficiently small perform full Shampoo update + # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this + # is not strictly correct. However we will use it for now, and + # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg) + + # pylint: disable=g-long-lambda,cell-var-from-loop + mat_g_updated = control_flow_ops.cond( + math_ops.mod(global_step, precond_update_interval) < 1, + lambda: self._update_mat_g( + mat_g, grad, axes, mat_gbar_decay_t, + mat_gbar_weight_t * precond_update_interval, i), + lambda: mat_g) + + if self._svd_interval == 1: + mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha) + else: + mat_h = control_flow_ops.cond( + math_ops.mod(global_step, svd_interval) < 1, + lambda: self._compute_power(var, mat_g_updated, shape[i], + neg_alpha, "H_" + str(i)), + lambda: self.get_slot(var, "H_" + str(i))) + + # mat_h is a square matrix of size d_i x d_i + # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor + # After contraction with a d_i x d_i tensor + # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor + # (the first dimension is contracted out, and the second dimension of + # mat_h is appended). After going through all the indices, it becomes + # a d_0 x ... x d_n tensor again. + preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h, + axes=([0], [0]), + name="precond_" + str(i)) + else: + # Tensor size is too large -- perform diagonal Shampoo update + grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) + if i == 0 and indices is not None: + assert self._mat_gbar_decay == 1.0 + mat_g_updated = state_ops.scatter_add(mat_g, indices, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow( + array_ops.gather(mat_g_updated, indices) + self._epsilon, + neg_alpha) + else: + mat_g_updated = self._weighted_average(mat_g, + self._mat_gbar_decay, + mat_gbar_decay_t, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha) + + # Need to do the transpose to ensure that the tensor becomes + # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above. + preconditioned_grad = array_ops.transpose( + preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h + + # Update the variable based on the Shampoo update + learning_rate_t = GetParam(self._learning_rate, global_step) + if indices is not None: + var_updated = state_ops.scatter_add( + var, indices, -learning_rate_t * preconditioned_grad) + else: + var_updated = state_ops.assign_sub(var, + learning_rate_t * preconditioned_grad) + return var_updated diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0a202ae293664d85ece884a505096455cde73c --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -0,0 +1,734 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional tests for AdaMoo optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import shampoo +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +TOLERANCE = 1e-3 + + +def np_power(mat_g, alpha): + """Computes mat_g^alpha for a square symmetric matrix mat_g.""" + + mat_u, diag_d, mat_v = np.linalg.svd(mat_g) + diag_d = np.power(diag_d, alpha) + return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v) + + +class ShampooTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicVector(self, use_resource_var): + """Similar to the full Adagrad update.""" + + size = 20 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g^{-0.5} * grad + # lr = 1 + mat_g = np.outer(grad_np, grad_np) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np = init_var_np - np.dot(mat_h, grad_np) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += np.outer(grad_np_2, grad_np_2) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np -= np.dot(mat_h, grad_np_2) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicMatrix(self, use_resource_var): + """Check update when gradient is a matrix.""" + size = [10, 5] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25} + # lr = 1 + mat_g1 = np.dot(grad_np, grad_np.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testBasicTensor(self, use_iterative_root, use_resource_var): + """Check update when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensor(self, use_iterative_root, use_resource_var): + self._testBasicTensor(use_iterative_root, use_resource_var) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeVector(self, use_resource_var): + """This is just the diagonal Adagrad update.""" + + size = 2000 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * gg^{-0.5} * grad + # lr = 1 + mat_g = grad_np * grad_np + 0.1 + new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += grad_np_2 * grad_np_2 + new_val_np -= np.power(mat_g, -0.5) * grad_np_2 + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeMatrix(self, use_resource_var): + """Gradient is a matrix, one of whose dimensions is large. + + We do diagonal updates for large dimensions. + + Args: + use_resource_var: use resource var as variables. + """ + + size = [2000, 3] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False)) + def testSparseUpdateLarge(self, use_resource_var): + """Check update when gradient is of type IndexSlices. + + We do diagonal updates for the first dimension, unless it is very small. + + Args: + use_resource_var: use resource var as variables. + """ + size = [2000, 3] + sample_size_1 = 100 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size_1, + replace=False)) + grad_np = np.random.rand(sample_size_1, size[1]) + + sample_size_2 = 7 + grad_indices_2 = np.sort(np.random.choice(np.arange(size[0]), sample_size_2, + replace=False)) + grad_np_2 = np.random.rand(sample_size_2, size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + grad_2 = ops.IndexedSlices( + constant_op.constant(grad_np_2, dtype=dtypes.float32), + constant_op.constant(grad_indices_2), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + # In this case the update lr * mat_left * grad * mat_right is + # of size 10 x 2. + # So the correct indices of var need to be updated. + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_g1_acc = np.zeros((size[0], 1)) + mat_g1_acc[grad_indices] += mat_g1 + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np + new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_g1_acc[grad_indices_2] += mat_g1 + mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + """Gradient is of type IndexSlices, but the first dimension is small. + + We create dense gradient and do the full update with SVD etc. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + + size = [100, 3, 5] + sample_size = 10 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size, + replace=False)) + grad_np = np.random.rand(sample_size, size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.125} grad + # lr = 1 + grad_dense = np.zeros_like(init_var_np) + grad_dense[grad_indices] = grad_np + + mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + self._testSparseUpdateSmall(use_iterative_root, use_resource_var) + + def _testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + """Check update with momentum when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + gbar_decay = 0.9 + gbar_weight = 0.1 + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, gbar_decay=gbar_decay, + gbar_weight=gbar_weight, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np = gbar_weight * grad_np + precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2 + precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + self._testBasicTensorWithMomentum(use_iterative_root, use_resource_var) + + def _testDelayedSVD(self, use_iterative_root, use_resource_var): + """Performing the SVD every nth step. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 20 + svd_interval = 5 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer(global_step, svd_interval=svd_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedSVD(self, use_iterative_root, use_resource_var): + self._testDelayedSVD(use_iterative_root, use_resource_var) + + def _testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + """Update the squared sum every nth step, drop the other steps. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 100 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + svd_interval = 20 + precond_update_interval = 5 + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer( + global_step, svd_interval=svd_interval, + precond_update_interval=precond_update_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + if (i + 1) % precond_update_interval == 0: + mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + * precond_update_interval) + mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + * precond_update_interval) + mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + * precond_update_interval) + + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + self._testDelayedPrecondUpdate(use_iterative_root, use_resource_var) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD index 5225ecc14fef3cec9506eceb776805b74a87714e..3ba3ee29ec79687df522eb330665a2ce80061682 100644 --- a/tensorflow/contrib/optimizer_v2/BUILD +++ b/tensorflow/contrib/optimizer_v2/BUILD @@ -193,6 +193,7 @@ cuda_py_test( srcs = ["rmsprop_test.py"], additional_deps = [ ":training", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 06ab58188a2fffa0e3a810d451875ca951a077b9..28a531dfecf275c48fea54310b93b5266a79899a 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import tracking @@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) - root.restore(core_saver.latest_checkpoint(checkpoint_directory)) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) for _ in range(num_training_steps): # TODO(allenl): Use a Dataset and serialize/checkpoint it. input_value = constant_op.constant([[3.]]) @@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize( model(input_value), global_step=root.global_step) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: status = root.restore(save_path=checkpoint_path) status.initialize_or_restore(session=session) @@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( @@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) def train_fn(): @function.defun diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index ec033c4a0163ba9ed39e55fa9e92dfdadc9a1b2f..a44bfd1bfd97e678fbf4c402ef5b1298dc518f75 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -38,12 +38,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop # Note that for eager execution, minimize expects a function instead of a @@ -131,12 +127,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) # pylint: disable=cell-var-from-loop def loss(): return 5 * var0 @@ -149,12 +141,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return constant_op.constant(5.0) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) @@ -165,12 +153,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) with self.assertRaisesRegexp(ValueError, 'No gradients provided for any variable'): @@ -179,12 +163,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop sgd_op = gradient_descent.GradientDescentOptimizer(3.0) diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py index ed68f6afbf8bf9678649c1ce6fc59c3b91026dc0..dc23ef241a43900ed40f029f1b857820459e43d0 100644 --- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py +++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function import copy -import itertools import math +from absl.testing import parameterized import numpy as np from tensorflow.contrib.optimizer_v2 import rmsprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -48,13 +49,8 @@ _TEST_PARAM_VALUES = [ [0.5, 0.95, 0.9, 1e-5, True, False], ] -_TESTPARAMS = [ - [data_type] + values - for data_type, values in itertools.product(_DATA_TYPES, _TEST_PARAM_VALUES) -] - -class RMSPropOptimizerTest(test.TestCase): +class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon, centered): @@ -87,362 +83,366 @@ class RMSPropOptimizerTest(test.TestCase): var_t[gindex] = var[gindex] - mom_t[gindex] return var_t, mg_t, rms_t, mom_t - def testDense(self): - # TODO(yori): Use ParameterizedTest when available - for (dtype, learning_rate, decay, momentum, - epsilon, centered, use_resource) in _TESTPARAMS: - with self.test_session(use_gpu=True): - # Initialize variables for numpy implementation. - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype) - - if use_resource: - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - opt = rmsprop.RMSPropOptimizer( - learning_rate=learning_rate, - decay=decay, - momentum=momentum, - epsilon=epsilon, - centered=centered) - - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - mg0 = opt.get_slot(var0, "mg") - self.assertEqual(mg0 is not None, centered) - mg1 = opt.get_slot(var1, "mg") - self.assertEqual(mg1 is not None, centered) - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 4 steps of RMSProp - for _ in range(1, 5): - update.run() - - var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( - var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, - decay, momentum, epsilon, centered) - var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( - var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, - decay, momentum, epsilon, centered) - - # Validate updated params - if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testMinimizeSparseResourceVariable(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) - loss = pred * pred - sgd_op = rmsprop.RMSPropOptimizer( - learning_rate=1.0, - decay=0.0, - momentum=0.0, - epsilon=0.0, - centered=False).minimize(loss) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[0., 1.]], var0.eval(), atol=0.01) - - def testMinimizeSparseResourceVariableCentered(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) - loss = pred * pred - sgd_op = rmsprop.RMSPropOptimizer( - learning_rate=1.0, - decay=0.0, - momentum=0.0, - epsilon=1.0, - centered=True).minimize(loss) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[-111, -138]], var0.eval(), atol=0.01) - - def testSparse(self): - # TODO(yori): Use ParameterizedTest when available - for (dtype, learning_rate, decay, - momentum, epsilon, centered, _) in _TESTPARAMS: - with self.test_session(use_gpu=True): - # Initialize variables for numpy implementation. - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype) - + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES)) + def testDense(self, dtype, param_value): + (learning_rate, decay, momentum, epsilon, centered, use_resource) = tuple( + param_value) + with self.test_session(use_gpu=True): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) - grads0_np_indices = np.array([0], dtype=np.int32) - grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([1])) - grads1_np_indices = np.array([1], dtype=np.int32) - grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([1])) - opt = rmsprop.RMSPropOptimizer( - learning_rate=learning_rate, - decay=decay, - momentum=momentum, - epsilon=epsilon, - centered=centered) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - mg0 = opt.get_slot(var0, "mg") - self.assertEqual(mg0 is not None, centered) - mg1 = opt.get_slot(var1, "mg") - self.assertEqual(mg1 is not None, centered) - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 4 steps of RMSProp - for _ in range(1, 5): - update.run() - - var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( - var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np, - learning_rate, decay, momentum, epsilon, centered) - var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy( - var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np, - learning_rate, decay, momentum, epsilon, centered) - - # Validate updated params - if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testWithoutMomentum(self): - for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(use_gpu=True): - var0 = variables.Variable([1.0, 2.0], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - opt = rmsprop.RMSPropOptimizer( - learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - # Step 1: the rms accumulators where 1. So we should see a normal - # update: v -= grad * learning_rate - update.run() - # Check the root mean square accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901, 0.901]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001, 0.90001]), rms1.eval()) - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - ]), var1.eval()) - # Step 2: the root mean square accumulators contain the previous update. - update.run() - # Check the rms accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) - ]), var1.eval()) - - def testWithMomentum(self): - for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(use_gpu=True): - var0 = variables.Variable([1.0, 2.0], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - - opt = rmsprop.RMSPropOptimizer( - learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - # Step 1: rms = 1, mom = 0. So we should see a normal - # update: v -= grad * learning_rate + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = rmsprop.RMSPropOptimizer( + learning_rate=learning_rate, + decay=decay, + momentum=momentum, + epsilon=epsilon, + centered=centered) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for _ in range(4): update.run() - # Check the root mean square accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901, 0.901]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001, 0.90001]), rms1.eval()) - # Check the momentum accumulators - self.assertAllCloseAccordingToType( - np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval()) - self.assertAllCloseAccordingToType( - np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval()) - - # Check that the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - ]), var1.eval()) - - # Step 2: the root mean square accumulators contain the previous update. + + var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( + var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, + decay, momentum, epsilon, centered) + var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( + var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, + decay, momentum, epsilon, centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([dtypes.float32, dtypes.float64]) + def testMinimizeSparseResourceVariable(self, dtype): + with self.test_session(): + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) + loss = pred * pred + sgd_op = rmsprop.RMSPropOptimizer( + learning_rate=1.0, + decay=0.0, + momentum=0.0, + epsilon=0.0, + centered=False).minimize(loss) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType( + [[0., 1.]], var0.eval(), atol=0.01) + + @parameterized.parameters([dtypes.float32, dtypes.float64]) + def testMinimizeSparseResourceVariableCentered(self, dtype): + with self.test_session(): + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) + loss = pred * pred + sgd_op = rmsprop.RMSPropOptimizer( + learning_rate=1.0, + decay=0.0, + momentum=0.0, + epsilon=1.0, + centered=True).minimize(loss) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType( + [[-111, -138]], var0.eval(), atol=0.01) + + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES)) + def testSparse(self, dtype, param_value): + (learning_rate, decay, momentum, epsilon, centered, _) = tuple( + param_value) + with self.test_session(use_gpu=True): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([1])) + grads1_np_indices = np.array([1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([1])) + opt = rmsprop.RMSPropOptimizer( + learning_rate=learning_rate, + decay=decay, + momentum=momentum, + epsilon=epsilon, + centered=centered) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for _ in range(4): update.run() - # Check the rms accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)), - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)) - ]), mom0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)), - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)) - ]), mom1.eval()) - - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))) - ]), var0.eval()) - - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) - ]), var1.eval()) + + var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( + var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np, + learning_rate, decay, momentum, epsilon, centered) + var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy( + var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np, + learning_rate, decay, momentum, epsilon, centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters(_DATA_TYPES) + def testWithoutMomentum(self, dtype): + with self.test_session(use_gpu=True): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + opt = rmsprop.RMSPropOptimizer( + learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: the rms accumulators where 1. So we should see a normal + # update: v -= grad * learning_rate + update.run() + # Check the root mean square accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901, 0.901]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001, 0.90001]), rms1.eval()) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) + ]), var1.eval()) + # Step 2: the root mean square accumulators contain the previous update. + update.run() + # Check the rms accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) + ]), var1.eval()) + + @parameterized.parameters(_DATA_TYPES) + def testWithMomentum(self, dtype): + with self.test_session(use_gpu=True): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + + opt = rmsprop.RMSPropOptimizer( + learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: rms = 1, mom = 0. So we should see a normal + # update: v -= grad * learning_rate + update.run() + # Check the root mean square accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901, 0.901]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001, 0.90001]), rms1.eval()) + # Check the momentum accumulators + self.assertAllCloseAccordingToType( + np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), + (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval()) + self.assertAllCloseAccordingToType( + np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), + (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval()) + + # Check that the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + ]), var1.eval()) + + # Step 2: the root mean square accumulators contain the previous update. + update.run() + # Check the rms accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)), + 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)) + ]), mom0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)), + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)) + ]), mom1.eval()) + + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))) + ]), var0.eval()) + + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) + ]), var1.eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index af3b2ad1b531b835f484a155efcc57bbe634f2df..c2166594e598857065a7fd109ec599a3b36e2d2c 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -22,8 +22,8 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.contrib.predictor import predictor from tensorflow.python.framework import ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver class ContribEstimatorPredictor(predictor.Predictor): @@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor): # pylint: disable=protected-access model_fn_ops = estimator._get_predict_ops(input_fn_ops.features) # pylint: enable=protected-access - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( config=config, diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index f275bc15adfa0a51a48964dff8edddbd45500e45..7886744b3ce7fc438bc73cb81bccfd0ddeea873e 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -108,6 +108,8 @@ def from_estimator(estimator, def from_saved_model(export_dir, signature_def_key=None, signature_def=None, + input_names=None, + output_names=None, tags=None, graph=None, config=None): @@ -121,6 +123,12 @@ def from_saved_model(export_dir, signature_def: A `SignatureDef` proto specifying the inputs and outputs for prediction. Only one of `signature_def_key` and `signature_def` should be specified. + input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` + that represent the input. The keys can be any string of the user's + choosing. + output_names: A dictionary mapping strings to `Tensor`s in the + `SavedModel` that represent the output. The keys can be any string of + the user's choosing. tags: Optional. Tags that will be used to retrieve the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be @@ -138,6 +146,8 @@ def from_saved_model(export_dir, export_dir, signature_def_key=signature_def_key, signature_def=signature_def, + input_names=input_names, + output_names=output_names, tags=tags, graph=graph, config=config) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index e3c48998305e9d9b6c185fd4c0f324fa0449c691..d9f179bee48de587976872dabb470cfd5c69114c 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -120,6 +120,7 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): scaled_weight_tensor = math_ops.multiply( weights, multiplier_tensor, name='mul_fold') + new_layer_tensor = _CloneWithNewOperands( match.layer_op, match.input_tensor, scaled_weight_tensor, match.batch_to_space_op) @@ -368,20 +369,20 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: bn_decay_zero, lambda: match.bn_decay_mean_tensor, name='freeze_moving_mean') + graph_editor.reroute_ts( [bn_decay_mean_out], [match.bn_decay_mean_tensor], can_modify=bn_decay_mean_consumers) - if fused_batch_norm is False: - bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) - bn_decay_var_out = utils.smart_cond( - use_mv_avg, - lambda: bn_decay_zero, - lambda: match.bn_decay_var_tensor, - name='freeze_moving_var') - graph_editor.reroute_ts( - [bn_decay_var_out], [match.bn_decay_var_tensor], - can_modify=bn_decay_var_consumers) + bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) + bn_decay_var_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_var_tensor, + name='freeze_moving_var') + graph_editor.reroute_ts( + [bn_decay_var_out], [match.bn_decay_var_tensor], + can_modify=bn_decay_var_consumers) correction_recip = utils.smart_cond( use_mv_avg, diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index 7c907ffd92c1ae0c762e41cc429b0e6ce053f6b9..3f8063cc022726cb745d42aba3c834c71e876e70 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -128,6 +128,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) + for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -216,6 +219,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = [scope + '/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -284,6 +289,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -351,6 +358,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -431,6 +440,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -515,6 +526,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -644,6 +657,22 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_op = graph.get_operation_by_name(out_op_name) self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + def _AssertMovingAveragesAreFrozen(self, graph, scope): + """Asserts to check if moving mean and variance are frozen. + + Args: + graph: Graph where the operations are located. + scope: Scope of batch norm op + """ + moving_average_mult = graph.get_operation_by_name( + scope + '/BatchNorm/AssignMovingAvg/mul') + self.assertTrue( + moving_average_mult.inputs[1].name.find('freeze_moving_mean/Merge') > 0) + moving_var_mult = graph.get_operation_by_name( + scope + '/BatchNorm/AssignMovingAvg_1/mul') + self.assertTrue( + moving_var_mult.inputs[1].name.find('freeze_moving_var/Merge') > 0) + def _CopyGraph(self, graph): """Return a copy of graph.""" meta_graph = saver_lib.export_meta_graph( diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index c2a8def48012c808da18587c8ff462fa33a363c0..a45840009b758881c14fb64b2d39af6cd4ec4bc4 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -75,7 +75,7 @@ class QuantOpsTest(googletest.TestCase): self.assertGreater(max_value, 0.0) self.assertLess(max_value, 1.0) - def testVariablesNotParitioned_LastValue(self): + def testVariablesNotPartitioned_LastValue(self): # Variables added should not use a default partiioner since they are # scalar. There would be a tensorflow error thrown if the partitioner was # respected by the rewrite. @@ -90,7 +90,7 @@ class QuantOpsTest(googletest.TestCase): is_training=True, vars_collection=_MIN_MAX_VARS) - def testVariablesNotParitioned_MovingAvg(self): + def testVariablesNotPartitioned_MovingAvg(self): # Variables added should not use a default partiioner since they are # scalar. There would be a tensorflow error thrown if the partitioner was # respected by the rewrite. diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 4fc315d901a86ac235513aad6eb34d7f90f61801..cb66fd1f76bcdb0a8f77fc7c476511576368ab4e 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -198,7 +198,7 @@ def _FindLayersToQuantize(graph): | [post_conv_correction] | - biasadd|folded_bias + [biasadd|folded_bias] | [bypass] | @@ -261,6 +261,16 @@ def _FindLayersToQuantize(graph): layer_output_pattern = graph_matcher.OneofPattern( [batch_to_space_pattern, layer_pattern]) + + # For separable convolutions, we are looking for a conv, followed by a conv + # with no activations between the two. + sep_conv_pattern = graph_matcher.OpTypePattern( + '|'.join(_QUANTIZABLE_TYPES), + inputs=[ + graph_matcher.OneofPattern([layer_output_pattern]), + graph_matcher.OpTypePattern('*') + ], + ordered_inputs=False) folded_bias_mul_pattern = graph_matcher.OpTypePattern( 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern], @@ -310,6 +320,7 @@ def _FindLayersToQuantize(graph): folded_bias_add_pattern, batch_norm_identity, bypass_pattern, + layer_pattern, ]) ]) @@ -393,6 +404,17 @@ def _FindLayersToQuantize(graph): layer_matches.append( _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None)) + # Look for separable convolutions here + sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern) + for match_result in sep_conv_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_identity_pattern) + activation_op = match_result.get_op(layer_pattern) + if layer_op not in matched_layer_set: + matched_layer_set.add(layer_op) + layer_matches.append( + _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None)) + return layer_matches diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 92ca4a1b0c3126ebccf2b525f01f4d6455c4d527..06ebcdfee1617af0c13cd6ed09a2ec5190c5a718 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -122,12 +122,67 @@ class QuantizeTest(test_util.TensorFlowTestCase): array_ops.identity(node, name='control_dependency') quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check if output of bias add is quantized + quantization_node_name = 'FakeQuantWithMinMaxVars' + conv_quant = graph.get_operation_by_name('test/test/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + + for op in graph.get_operations(): + if op.type == quantization_node_name: + quant_op = graph.get_operation_by_name(op.name) + # Scan through all FakeQuant operations, ensuring that the activation + # identity op isn't in the consumers of the operation. + consumers = [] + for output in quant_op.outputs: + consumers.extend(output.consumers()) + + self.assertNotIn('test/relu6', [c.name for c in consumers]) + + def testInsertQuantOpInSeparableConv2d(self): + self._RunTestOverParameters(self._TestInsertQuantOpInSeparableConv2d) + + def _TestInsertQuantOpInSeparableConv2d(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth)) + conv = separable_conv2d( + input1, + 3, [5, 5], + stride=2, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = nn_ops.relu6(node, name='test/relu6') + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check if output of bias add is quantized quantization_node_name = 'FakeQuantWithMinMaxVars' conv_quant = graph.get_operation_by_name('test/test/conv_quant/' + quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) + # Check if weights for both convs inside seperable conv are quantized + pointwise_weight_quant = graph.get_operation_by_name( + 'test/test/weights_quant/' + quantization_node_name) + self.assertEqual(pointwise_weight_quant.type, quantization_node_name) + depthwise_weight_quant = graph.get_operation_by_name( + 'test/test/separable_conv2d/weights_quant/' + quantization_node_name) + self.assertEqual(depthwise_weight_quant.type, quantization_node_name) + + # Check if activations after first depthwise conv are quantized. + depthwise_act_quant = graph.get_operation_by_name( + 'test/test/separable_conv2d/act_quant/' + quantization_node_name) + self.assertEqual(depthwise_act_quant.type, quantization_node_name) + for op in graph.get_operations(): if op.type == quantization_node_name: quant_op = graph.get_operation_by_name(op.name) @@ -139,6 +194,33 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertNotIn('test/relu6', [c.name for c in consumers]) + def testLayerActivationQuantized(self): + self._RunTestOverParameters(self._TestLayerActivationQuantized) + + def _TestLayerActivationQuantized(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + biases_initializer=None, + scope='test') + # Ensure that both weights and output of activations are quantized + # when we have a conv->relu6 with no bias add + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + activation_op = graph.get_operation_by_name('test/Relu6') + conv_op = graph.get_operation_by_name('test/Conv2D') + self.assertTrue('test/weights_quant/FakeQuantWithMinMaxVars:0' in + [tensor_in.name for tensor_in in conv_op.inputs]) + self.assertTrue('FakeQuantWithMinMaxVars' in + [op.type for op in activation_op.outputs[0].consumers()]) + def testFinalLayerQuantized(self): self._RunTestOverParameters(self._TestFinalLayerQuantized) diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py index 0f19ac7dbe0cee2eb6c780ec5ea6266bc847abd7..f23194a6f2e64e0619049bac51891d6d6099831f 100644 --- a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py +++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py @@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): func, args = self._CELLDEFS[celldef_name] return func(*args) - def _CreateInputs(self): - inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE, - FunctionalRnnTest._TOTAL_TIME, - FunctionalRnnTest._INPUT_SIZE]) + def _CreateInputs(self, time_major=False): + if time_major: + inputs = np.random.random([ + FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE + ]) + else: + inputs = np.random.random([ + FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME, + FunctionalRnnTest._INPUT_SIZE + ]) # Always leave one time slot empty, to check max_length behavior. sequence_length = np.random.randint( 0, high=FunctionalRnnTest._TOTAL_TIME - 1, @@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): dtype=np.int) return (inputs, sequence_length) - def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs, - tf_sequence_length, initial_state=None, - time_major=None, scope=None): - tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs, - sequence_length=tf_sequence_length, - initial_state=initial_state, - dtype=dtypes.float32, - time_major=time_major, - scope=scope) + def _CreateSymmetricInputs(self): + # total time = batch size + inputs = np.zeros( + (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE)) + for i in range(FunctionalRnnTest._BATCH_SIZE): + for j in range(i, FunctionalRnnTest._BATCH_SIZE): + inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE]) + inputs[j][i] = inputs[i][j] + + # Always leave one time slot empty, to check max_length behavior. + sequence_length = np.random.randint( + 0, + high=FunctionalRnnTest._BATCH_SIZE - 1, + size=FunctionalRnnTest._BATCH_SIZE, + dtype=np.int) + return (inputs, sequence_length) + + def _CreateRnnGraph(self, + create_rnn_computation_func, + cell, + tf_inputs, + tf_sequence_length, + is_bidirectional, + initial_state=None, + time_major=None, + scope=None): + if is_bidirectional: + tf_result = create_rnn_computation_func( + cell_fw=cell, + cell_bw=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) + else: + tf_result = create_rnn_computation_func( + cell=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + initial_state=initial_state, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) grad = gradients_impl.gradients(tf_result, variables.trainable_variables()) return {'inference': tf_result, 'grad': grad} @@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): variable_cache[n] = v def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache, - is_dynamic): + is_dynamic, time_major=None, is_bidirectional=False): with ops.Graph().as_default() as graph: tf_inputs = array_ops.placeholder( dtypes.float32, shape=numpy_inputs.shape) tf_slen = array_ops.placeholder(dtypes.int32) feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen} cell = self._CreateCell(cell_name) - fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn - fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen) + if is_dynamic: + if is_bidirectional: + fn = rnn_lib.bidirectional_dynamic_rnn + else: + fn = rnn_lib.dynamic_rnn + else: + if is_bidirectional: + fn = functional_rnn.bidirectional_functional_rnn + else: + fn = functional_rnn.functional_rnn + + fetches = self._CreateRnnGraph( + fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major) with self.test_session(graph=graph) as sess: sess.run(variables.global_variables_initializer()) # Note that cell.trainable_variables it not always set. @@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + def testLstmWithTimeMajorInputs(self): + """Checks an LSTM against the reference implementation, with time_major.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major=True) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major) + _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstmWithTimeMajorInputs(self): + """Checks a bi-directional LSTM with time-major inputs.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn( + *(args + [False]), time_major=time_major, is_bidirectional=True) + _, dyn_rnn = self._RunRnn( + *(args + [True]), time_major=time_major, is_bidirectional=True) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + # TODO(b/112170761): comment out this line after the bug is fixed. + # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstm(self): + """Checks time-major and batch-major rnn produce consistent results.""" + time_major_inputs, np_slen = self._CreateInputs(True) + batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2]) + var_cache = {} + args = [np_slen, 'lstm', var_cache, False] + _, time_major_rnn = self._RunRnn( + *([time_major_inputs] + args), time_major=True, is_bidirectional=True) + _, batch_major_rnn = self._RunRnn( + *([batch_major_inputs]+ args), time_major=False, is_bidirectional=True) + # Convert the batch-major outputs to be time-major before the comparasion. + outputs, state = batch_major_rnn['inference'] + outputs = [np.transpose(x, [1, 0, 2]) for x in outputs] + batch_major_rnn['inference'] = [outputs, state] + self.assertAllClose(time_major_rnn['inference'], + batch_major_rnn['inference']) + self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad']) + + def testBidirectionalLstmWithSymmetricInputs(self): + """Checks a bi-directional LSTM with symmetric inputs. + + time-major and batch-major rnn produce the same result with symmetric + inputs. + """ + np_inputs, np_slen = self._CreateSymmetricInputs() + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, time_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=True, is_bidirectional=True) + _, batch_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=False, is_bidirectional=True) + _, time_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=True, is_bidirectional=True) + _, batch_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=False, is_bidirectional=True) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_func_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_func_rnn['grad']) + self.assertAllClose(time_major_dyn_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad']) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_dyn_rnn['grad']) + if __name__ == '__main__': test_lib.main() diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index a085474c1bf6117ba5663139c78d8f08f71392d3..67a8f59c3c03d01a5957a9eff8bd026e70770a45 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -206,7 +206,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length): lengths = array_ops.tile(array_ops.reshape(sequence_length, [-1, 1]), [1, max_time]) last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1), - dtype=dtypes.float32) + dtype=state_var.dtype) last_idx = array_ops.transpose(last_idx) last_idx_for_bcast = array_ops.expand_dims(last_idx, -1) sliced = math_ops.multiply(last_idx_for_bcast, state_var) @@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None, inputs=inputs, cell_fn=func_cell.cell_step, use_tpu=use_tpu) - return _PostProcessOutput(extended_acc_state, extended_final_state, - func_cell, inputs_flat[0].shape[0], sequence_length) + tf_output, tf_state = _PostProcessOutput( + extended_acc_state, extended_final_state, func_cell, + inputs_flat[0].shape[0], sequence_length) + + if time_major: + tf_output = array_ops.transpose(tf_output, [1, 0, 2]) + return tf_output, tf_state def bidirectional_functional_rnn( diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py index fa16b82ab62f27d034c3ca7584e7e1ca14be6f9b..4f289e0c85e2260a44a8ea2f3f1d6cacbc839f66 100644 --- a/tensorflow/contrib/recurrent/python/ops/recurrent.py +++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py @@ -79,7 +79,7 @@ def _Index(struct, index): """ index = ops.convert_to_tensor(index) index.get_shape().assert_has_rank(0) - return nest.map_structure(lambda x: x[index], struct) + return nest.map_structure(lambda x: array_ops.gather(x, index), struct) def _Update(struct_acc, struct_x, t): diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 26fd4e2023806765ea4088f4c13a780ca7338bff..fbb50befdfb2ccbd97465c11f8219e604a0ebc18 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -93,3 +93,32 @@ py_test( "//tensorflow/python/saved_model:utils", ], ) + +py_library( + name = "keras_saved_model", + srcs = ["python/saved_model/keras_saved_model.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:lib", + "//tensorflow/python:util", + "//tensorflow/python/keras:engine", + "//tensorflow/python/saved_model:constants", + ], +) + +py_test( + name = "keras_saved_model_test", + size = "small", + srcs = ["python/saved_model/keras_saved_model_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":saved_model_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/keras", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py index b4f27a055dad7a5b95112d561cc878609a558f8d..95e1a8967b2223fd3feb112af3cbe0c5991d2d03 100644 --- a/tensorflow/contrib/saved_model/__init__.py +++ b/tensorflow/contrib/saved_model/__init__.py @@ -24,11 +24,12 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long +from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import * from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import * # pylint: enable=unused-import,widcard-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ["get_signature_def_by_key"] +_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py index 7b91622b6127413ce122c4166a18255b65365d32..e3b76bb6f34846f02ccdf623d48ddd9c5909fdce 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py +++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py @@ -24,5 +24,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a969f053d3f1ded8aecd6411a62a198df48bb0 --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -0,0 +1,108 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Utility functions to save/load keras Model to/from SavedModel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat + + +def save_model(model, saved_model_path): + """Save a `tf.keras.Model` into Tensorflow SavedModel format. + + `save_model` generates such files/folders under the `saved_model_path` folder: + 1) an asset folder containing the json string of the model's + configuration(topology). + 2) a checkpoint containing the model weights. + + Note that subclassed models can not be saved via this function, unless you + provide an implementation for get_config() and from_config(). + Also note that `tf.keras.optimizers.Optimizer` instances can not currently be + saved to checkpoints. Use optimizers from `tf.train`. + + Args: + model: A `tf.keras.Model` to be saved. + saved_model_path: a string specifying the path to the SavedModel directory. + + Raises: + NotImplementedError: If the passed in model is a subclassed model. + """ + if not model._is_graph_network: + raise NotImplementedError + + # save model configuration as a json string under assets folder. + model_json = model.to_json() + assets_destination_dir = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.ASSETS_DIRECTORY)) + + if not file_io.file_exists(assets_destination_dir): + file_io.recursive_create_dir(assets_destination_dir) + + model_json_filepath = os.path.join( + compat.as_bytes(assets_destination_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) + file_io.write_string_to_file(model_json_filepath, model_json) + + # save model weights in checkpoint format. + checkpoint_destination_dir = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.VARIABLES_DIRECTORY)) + + if not file_io.file_exists(checkpoint_destination_dir): + file_io.recursive_create_dir(checkpoint_destination_dir) + + checkpoint_prefix = os.path.join( + compat.as_text(checkpoint_destination_dir), + compat.as_text(constants.VARIABLES_FILENAME)) + model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) + + +def load_model(saved_model_path): + """Load a keras.Model from SavedModel. + + load_model reinstantiates model state by: + 1) loading model topology from json (this will eventually come + from metagraph). + 2) loading model weights from checkpoint. + + Args: + saved_model_path: a string specifying the path to an existing SavedModel. + + Returns: + a keras.Model instance. + """ + # restore model topology from json string + model_json_filepath = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.ASSETS_DIRECTORY), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) + model_json = file_io.read_file_to_string(model_json_filepath) + model = model_from_json(model_json) + + # restore model weights + checkpoint_prefix = os.path.join( + compat.as_text(saved_model_path), + compat.as_text(constants.VARIABLES_DIRECTORY), + compat.as_text(constants.VARIABLES_FILENAME)) + model.load_weights(checkpoint_prefix) + return model diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..107ae1b07b777570e4124337595ceecd6e33cd0b --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Tests for saving/loading function for keras Model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import numpy as np + +from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model +from tensorflow.python import keras +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.platform import test +from tensorflow.python.training import training as training_module + + +class TestModelSavingandLoading(test.TestCase): + + def test_saving_sequential_model(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + + loaded_model = keras_saved_model.load_model(temp_saved_model) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_sequential_model_without_compile(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + + x = np.random.random((1, 3)) + ref_y = model.predict(x) + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + def test_saving_functional_model(self): + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_functional_model_without_compile(self): + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_with_tf_optimizer(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + loaded_model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc']) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + # test that new updates are the same with both models + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + + ref_loss = model.train_on_batch(x, y) + loss = loaded_model.train_on_batch(x, y) + self.assertAllClose(ref_loss, loss, atol=1e-05) + + ref_y = model.predict(x) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + # test saving/loading again + keras_saved_model.save_model(loaded_model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + def test_saving_subclassed_model_raise_error(self): + # For now, saving subclassed model should raise an error. It should be + # avoided later with loading from SavedModel.pb. + + class SubclassedModel(training.Model): + + def __init__(self): + super(SubclassedModel, self).__init__() + self.layer1 = keras.layers.Dense(3) + self.layer2 = keras.layers.Dense(1) + + def call(self, inp): + return self.layer2(self.layer1(inp)) + + model = SubclassedModel() + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + temp_saved_model = os.path.join(temp_dir, 'saved_model') + with self.assertRaises(NotImplementedError): + keras_saved_model.save_model(model, temp_saved_model) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py index 653c030a04c2bbc7e3ee49b9c85a781fb49de8d0..4db8dc2ca090534f2cda66bd55c30dfa389b860a 100644 --- a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py +++ b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py @@ -90,22 +90,28 @@ def overlap_and_add(signal, frame_step, name=None): raise ValueError("frame_step must be an integer. Got %s" % frame_step.dtype) - # If frame_length and frame_step are known at graph construction time, check - # frame_step is less than or equal to frame_length. - frame_step_static = tensor_util.constant_value(frame_step) - if (frame_step_static is not None and signal.shape.ndims is not None and - signal.shape[-1].value is not None and - frame_step_static > signal.shape[-1].value): - raise ValueError( - "frame_step (%d) must be less than or equal to frame_length (%d)" % ( - frame_step_static, signal.shape[-1].value)) - signal_shape = array_ops.shape(signal) # All dimensions that are not part of the overlap-and-add. Can be empty for # rank 2 inputs. outer_dimensions = signal_shape[:-2] + # If frame_length and frame_step are known at graph construction time, check + # frame_step is less than or equal to frame_length. + frame_step_static = tensor_util.constant_value(frame_step) + if (frame_step_static is not None and signal.shape.ndims is not None and + signal.shape[-1].value is not None): + if frame_step_static > signal.shape[-1].value: + raise ValueError( + "frame_step (%d) must be less than or equal to " + "frame_length (%d)" % ( + frame_step_static, signal.shape[-1].value)) + # If frame_length is equal to frame_step, there's no overlap so just + # reshape the tensor. + if frame_step_static == signal.shape[-1].value: + return array_ops.reshape(signal, array_ops.concat( + [outer_dimensions, [-1]], 0)) + signal_rank = array_ops.rank(signal) frames = signal_shape[-2] frame_length = signal_shape[-1] diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc index 7e25579070eef13682dedfcd3c9e435333f65687..6cb2c881e2428dfcac3187bf7364582e857b9879 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc @@ -51,7 +51,8 @@ std::unique_ptr CreateBinaryDecisionNodeEvaluator( InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator( const decision_trees::InequalityTest& test, int32 left, int32 right) : BinaryDecisionNodeEvaluator(left, right) { - safe_strto32(test.feature_id().id().value(), &feature_num_); + CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_)) + << "Invalid feature ID: [" << test.feature_id().id().value() << "]"; threshold_ = test.threshold().float_value(); include_equals_ = test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL; @@ -72,7 +73,9 @@ ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator( : BinaryDecisionNodeEvaluator(left, right) { for (int i = 0; i < test.oblique().features_size(); ++i) { int32 val; - safe_strto32(test.oblique().features(i).id().value(), &val); + CHECK(safe_strto32(test.oblique().features(i).id().value(), &val)) + << "Invalid feature ID: [" << test.oblique().features(i).id().value() + << "]"; feature_num_.push_back(val); feature_weights_.push_back(test.oblique().weights(i)); } @@ -97,7 +100,8 @@ int32 ObliqueInequalityDecisionNodeEvaluator::Decide( MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator( const decision_trees::MatchingValuesTest& test, int32 left, int32 right) : BinaryDecisionNodeEvaluator(left, right) { - safe_strto32(test.feature_id().id().value(), &feature_num_); + CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_)) + << "Invalid feature ID: [" << test.feature_id().id().value() << "]"; for (const auto& val : test.value()) { values_.push_back(val.float_value()); } diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index b0337c3fe9246c741b3d8a6d051ad1149ee5b408..5b54cb76b4f49bf255b3da372a7ee7bad8b66544 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -3,7 +3,7 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. -package(default_visibility = ["//tensorflow:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -37,7 +37,9 @@ tf_cuda_cc_test( "nomac", ], deps = [ + "//tensorflow/core:gpu_init", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + if_tensorrt([ @@ -83,11 +85,12 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":test_utils", ":trt_allocator", + ":trt_conversion", ":trt_logging", ":trt_plugins", ":trt_resources", - ":trt_conversion", ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -120,7 +123,6 @@ tf_cuda_library( tf_gen_op_wrapper_py( name = "trt_engine_op", - gen_locally = True, deps = [ ":trt_engine_op_op_lib", ":trt_logging", @@ -183,6 +185,8 @@ py_library( ], ) +# TODO(aaroey): this wrapper has been causing troubles of double linking, so +# either get rid of it, or split to make it contain minimum dependencies. tf_py_wrap_cc( name = "wrap_conversion", srcs = ["trt_conversion.i"], @@ -191,6 +195,7 @@ tf_py_wrap_cc( "//tensorflow/python:platform/base.i", ], deps = [ + ":test_utils", ":trt_conversion", ":trt_engine_op_kernel", "//third_party/python_runtime:headers", @@ -263,6 +268,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":test_utils", ":trt_allocator", ":trt_plugins", ":trt_logging", @@ -273,7 +279,6 @@ tf_cuda_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:framework_lite", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -384,15 +389,16 @@ cuda_py_tests( "test/base_test.py", # "test/batch_matmul_test.py", # "test/biasadd_matmul_test.py", - "test/binary_tensor_weight_broadcast_test.py", - "test/concatenation_test.py", + # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation + # "test/concatenation_test.py", # Blocked by trt4 installation "test/const_broadcast_test.py", "test/multi_connection_neighbor_engine_test.py", "test/neighboring_engine_test.py", - "test/unary_test.py", - # "test/rank_two_test.py", + "test/rank_two_test.py", + # "test/unary_test.py", # Blocked by trt4 installation # "test/vgg_block_nchw_test.py", # "test/vgg_block_test.py", + "test/memory_alignment_test.py", ], additional_deps = [ ":tf_trt_integration_test_base", @@ -411,4 +417,17 @@ cc_library( srcs = ["convert/utils.cc"], hdrs = ["convert/utils.h"], copts = tf_copts(), + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_utils", + srcs = ["test/utils.cc"], + hdrs = ["test/utils.h"], + deps = [ + "//tensorflow/core:lib", + "@com_googlesource_code_re2//:re2", + ], ) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 3383f6bc9b99879a1c661a0d49e42c6f3b878f66..21ec8b0b30c595a1fad01b69bce9b16393742704 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -29,9 +30,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" -#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -195,20 +194,44 @@ tensorflow::Status ConvertCalibGraphToInferGraph( return tensorflow::Status::OK(); } -// Entry function from Python. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, int precision_mode, int minimum_segment_size, bool is_dyn_op, int max_cached_engines, std::vector cached_engine_batches) { - // optimization pass + // Create GrapplerItem. tensorflow::grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; - // grappler requires a virtual cluster with a proper GPU device - // in order to calculate flops>0 or fails with FATAL - // We add numbers from a Pascal card here to have flops>0 + + // TODO(aaroey): we should have used single machine cluster like the + // following, but the problem is then wrap_conversion will depend on + // direct_session and cause double linking problems. To fix this we need to + // fix or get rid of the swig dependency. Here we use VirtualCluster + // as a work around, and we need to create a session to initialize the + // underlying device before calling this method. +#if 0 + // Create single machine cluster. Note that this will create a session and + // initialize the gpu devices. + const int num_cpu_cores = + tensorflow::grappler::GetNumAvailableLogicalCPUCores(); + const int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); + VLOG(2) << "cpu_cores: " << num_cpu_cores; + VLOG(2) << "gpus: " << num_gpus; + const int timeout_s = 60 * 10; + std::unique_ptr cluster( + new tensorflow::grappler::SingleMachine( + timeout_s, num_cpu_cores, num_gpus)); + // These settings are the defaults in tensorflow/python/grappler/cluster.py. + cluster->DisableDetailedStats(true); + cluster->AllowSoftPlacement(true); + cluster->SetNumWarmupSteps(10); + TF_RETURN_IF_ERROR(cluster->Provision()); +#else + // Create virtual cluster. Grappler requires a virtual cluster with a proper + // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode. + // We add numbers from a Pascal card here to have flops>0. tensorflow::DeviceProperties device_properties; device_properties.set_type("GPU"); device_properties.mutable_environment()->insert({"architecture", "6"}); @@ -217,47 +240,43 @@ tensorflow::Status ConvertGraphDefToTensorRT( std::unique_ptr cluster( new tensorflow::grappler::VirtualCluster( {{"/GPU:0", device_properties}})); +#endif - // single machine - int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); - int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); - VLOG(2) << "cpu_cores: " << num_cpu_cores; - VLOG(2) << "gpus: " << num_gpus; + // Create RewriterConfig. tensorflow::RewriterConfig rw_cfg; - // use only const folding and layout for the time being since new optimizers - // break the graph for us + // TODO(aaroey): use only const folding and layout for the time being since + // new optimizers break the graph for trt. rw_cfg.add_optimizers("constfold"); rw_cfg.add_optimizers("layout"); - rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE); + auto optimizer = rw_cfg.add_custom_optimizers(); + optimizer->set_name("TensorRTOptimizer"); + auto& parameters = *(optimizer->mutable_parameter_map()); + parameters["minimum_segment_size"].set_i(minimum_segment_size); + parameters["max_batch_size"].set_i(max_batch_size); + parameters["is_dynamic_op"].set_b(is_dyn_op); + parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes); + TF_RETURN_IF_ERROR(GetPrecisionModeName( + precision_mode, parameters["precision_mode"].mutable_s())); + parameters["maximum_cached_engines"].set_i(max_cached_engines); + if (!cached_engine_batches.empty()) { + auto list = parameters["cached_engine_batches"].mutable_list(); + for (const int batch : cached_engine_batches) { + list->add_i(batch); + } + } + + // Run optimizer. tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); - tensorflow::GraphDef gdef; - TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef)); - item.graph = gdef; - - // AJ refactoring shape inference through grappler/GraphProperties. - tensorflow::grappler::GraphProperties static_graph_properties(item); - TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - // Build full graph - ConversionParams cp; - cp.input_graph_def = &gdef; - cp.output_names = &output_names; - cp.max_batch_size = max_batch_size; - cp.output_graph_def = new_graph_def; - cp.precision_mode = precision_mode; - cp.is_dyn_op = is_dyn_op; - cp.max_cached_engines = max_cached_engines; - cp.cached_engine_batches = cached_engine_batches; - cp.minimum_segment_size = minimum_segment_size; - cp.graph_properties = &static_graph_properties; - cp.max_workspace_size_bytes = max_workspace_size_bytes; + TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); + if (VLOG_IS_ON(5)) { std::fstream f; f.open("TRTConversionInput.pb", std::fstream::out | std::fstream::binary | std::fstream::trunc); - f << gdef.SerializeAsString(); + f << new_graph_def->SerializeAsString(); f.close(); } - return ConvertAfterShapes(cp); + return Status::OK(); } // Function to get subsegment information structure. @@ -268,11 +287,10 @@ tensorflow::Status GetEngineInfo( const std::unordered_map& node_map, const std::vector& reverse_topo_order, EngineInfo* info) { - std::vector subgraph_node_ids; + std::vector subgraph_node_ids; // Topologically sorted node ids. + std::set subgraph_node_names = segment_nodes; std::set added_const_node_ids; // Used to prevent double insertion. std::set segment_devices; - int input_port = 0; - int output_port = 0; // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output @@ -280,13 +298,12 @@ tensorflow::Status GetEngineInfo( // input/output edges must be in different split of the graph. // TODO(aaroey): consider using node id and port instead. // TODO(aaroey): using topo order instead of reverting reverse topo order. - std::unordered_map created_edges; + std::unordered_map input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = node_map.at(node_name); + auto node = *it; auto node_device = node->requested_device(); if (!node_device.empty()) { segment_devices.insert(node_device); @@ -299,64 +316,93 @@ tensorflow::Status GetEngineInfo( } } const int node_id = node->id(); + subgraph_node_ids.push_back(node_id); + // Create input connections. for (const auto edge : node->in_edges()) { auto input_node = edge->src(); - if (segment_nodes.count(input_node->name()) == 0 && - !edge->IsControlEdge() && !input_node->IsSource()) { - // Add constant input node into the segment. We don't care if it has - // other output edges going into other engines or TF nodes. Since we add - // it only to the subsegment node list, not the subsegment itself, it - // won't be removed from the graph. If it doesn't have any edges, TF - // will prune it out. - if (input_node->type_string() == "Const") { - if (added_const_node_ids.count(input_node->id()) == 0) { - added_const_node_ids.insert(input_node->id()); - subgraph_node_ids.push_back(input_node->id()); - } + if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control input. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else if (input_node->type_string() == "Const") { + // Add constant data input nodes into the segment graphdef (thus also in + // the engine). We don't care if it has other output edges going into + // other engines or TF nodes. Since we add it only to the segment + // graphdef, not the segment itself, it won't be removed from the graph. + // If it doesn't have any edges, TF will prune it out. + // + // Note that the segmenter already ensure that the constant data input + // is valid and suppported by the engine. + if (!added_const_node_ids.insert(input_node->id()).second) { + // Already added before. + continue; + } + VLOG(1) << "Adding const node " << input_node->name(); + QCHECK(subgraph_node_names.insert(input_node->name()).second); + // Since we already add (duplicate) the const input node to the segment + // graphdef, it's now not a data dependency any more, but to make the + // dependency correct we still add a control dependency. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else { + // Non-const data input. + int port = Graph::kControlSlot - 1; + // Use the source non-segment node name/port as key. + const string s = StrCat(input_node->name(), ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + if (input_to_engine_port.count(s)) { + port = input_to_engine_port.at(s); } else { - string s(input_node->name()); - StrAppend(&s, ":", edge->src_output()); - VLOG(1) << "Input edge = " << s; - int port = input_port; - if (created_edges.count(s)) { - port = created_edges.at(s); - } else { - created_edges.insert({s, port}); - input_port++; - } - info->connections.emplace_back(input_node->name(), input_node->id(), - edge->src_output(), node_name, node_id, - edge->dst_input(), true, port); + port = input_to_engine_port.size(); + input_to_engine_port.insert({s, port}); } + info->connections.emplace_back( + input_node->name(), input_node->id(), edge->src_output(), node_name, + node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // We need to add possible const input nodes before adding this node in - // order to keep the topological order. - subgraph_node_ids.push_back(node_id); + // Create output connections. for (const auto edge : node->out_edges()) { auto output_node = edge->dst(); - if (segment_nodes.count(output_node->name()) == 0 && - !edge->IsControlEdge() && !output_node->IsSink()) { - string s(node_name); - StrAppend(&s, ":", edge->src_output()); + if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control output. + info->connections.emplace_back(output_node->name(), output_node->id(), + node_name, node_id, + /*input_edge=*/false); + } else { + // Data output. + int port = Graph::kControlSlot - 1; + // Use the source segment node name/port as key. + const string s = StrCat(node_name, ":", edge->src_output()); VLOG(1) << "Output edge = " << s; - int port = output_port; - if (created_edges.count(s)) { - port = created_edges.at(s); + if (output_to_engine_port.count(s)) { + port = output_to_engine_port.at(s); } else { - created_edges.insert({s, port}); - output_port++; + port = output_to_engine_port.size(); + output_to_engine_port.insert({s, port}); } - info->connections.emplace_back(output_node->name(), output_node->id(), - edge->dst_input(), node_name, node_id, - edge->src_output(), false, port); + info->connections.emplace_back( + output_node->name(), output_node->id(), edge->dst_input(), + node_name, node_id, edge->src_output(), /*input_edge=*/false, port); } } - } + } // For each segment node in topological order. + // Construct the const nodes first. + subgraph_node_ids.insert(subgraph_node_ids.begin(), + added_const_node_ids.begin(), + added_const_node_ids.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_ids, &info->connections, - &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_node_names, subgraph_node_ids, + &info->connections, &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -366,94 +412,137 @@ tensorflow::Status GetEngineInfo( << "but this shouldn't have happened"; info->device = *segment_devices.begin(); } else { - VLOG(1) << "Segment devices size is 0"; + LOG(ERROR) << "Can't find a device placement for the op!"; } return Status::OK(); } -// Function to insert a TRT node into the graph. The graph is not modified if -// the returned status is not ok. -// 'alloc' is only used for creating static engine. -tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, - const std::vector& infos, int pos, +// Helper function to update edge connection from the removed node to the +// engine node. If an outside node is gone, it must have been absorbed into +// an engine node. Find the engine node. +void UpdateToEngineNode(const std::vector& infos, + const size_t my_engine_id, + const std::vector& engine_nodes, + const bool is_input_edge, const string& node_name, + tensorflow::Node** node, int* port) { + for (size_t t = 0; t < infos.size(); ++t) { + if (t == my_engine_id) { + continue; + } + const auto& info = infos.at(t); + for (const auto& eng_conn : info.connections) { + // If the connection being updated is an input connection, the source of + // the connection must be an output connection of another engine. And vise + // versa. + if (is_input_edge == eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == node_name && + eng_conn.inside_port == *port) { + *node = CHECK_NOTNULL(engine_nodes[t]); + QCHECK_EQ(info.engine_name, (**node).name()) + << "Engine name mismatch: " << info.engine_name << " vs " + << (**node).name(); + *port = eng_conn.port_number; + return; + } + } + } + LOG(FATAL) << "Node " << (**node).name() << " not found in any engine."; +} + +// Function to insert a TRT engine node into the graph. +// Create engine nodes in the following way: +// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos] +// 2. When an engine node is created, add it into the graph with necessary +// re-wiring. +// 2.1. If the outside connected node is existing, connect the engine +// node to it. +// 2.2. If the outside connected node is gone, it must have been absorted +// into another engine node (which was processed before the processing +// one). Connect to the pre-existing engine node instead. +// 3. In this way, we ensure the graph is topologically sort-able after each +// invocation of CreateTRTNode(). +tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, + int max_batch_size, tensorflow::Graph* graph, nvinfer1::IGpuAllocator* alloc, - int max_batch_size) { + std::vector* engine_nodes) { const auto& info = infos.at(pos); + TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail"); std::vector output_shape_protos; std::vector input_shape_protos; std::vector input_shapes; std::vector inputs; + std::vector input_nodes; + std::vector control_input_nodes; + std::unordered_set control_input_names; std::vector out_types; - VLOG(1) << "Processing " << info.engine_name; - // Update the shape and data types of input/output nodes, and find all unique - // inputs. + VLOG(1) << "Processing " << info.engine_name; + // Collect needed info for creating the engine node in the graph for (const auto& conn : info.connections) { - if (!conn.is_input_edge) { - // Set the shapes and data types of output edge. - tensorflow::TensorShapeProto out_shape; - // shape of the output node inside segment - conn.inside_shape.AsProto(&out_shape); - if (output_shape_protos.size() <= conn.port_number) { - output_shape_protos.resize(conn.port_number + 1); - out_types.resize(conn.port_number + 1); + // Control edges + if (conn.is_control_edge()) { + // Skip control outputs for now. control output info are not needed for + // node creation and will be processed later. + if (!conn.is_input_edge) continue; + + // Rewrire control input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = tensorflow::Graph::kControlSlot; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + QCHECK_EQ(Graph::kControlSlot, port); } - output_shape_protos.at(conn.port_number) = out_shape; - out_types.at(conn.port_number) = conn.connection_type; - continue; - } - - // Set the shapes and data types of input edge. - tensorflow::TensorShapeProto in_shape; - conn.outside_shape.AsProto(&in_shape); - if (input_shape_protos.size() <= conn.port_number) { - input_shape_protos.resize(conn.port_number + 1); - input_shapes.resize(conn.port_number + 1); - } - input_shape_protos.at(conn.port_number) = in_shape; - input_shapes.at(conn.port_number) = conn.outside_shape; - - string input_node = conn.outside_node_name; - int input_port = conn.outside_port; - bool found_engine = false; - // Rewire the inputs to other engines if they contain original input node. - // Note that we use the information of the engine here, not the information - // of the created TRT nodes, so we're able to find all the connections to - // any other engines beforehand. - for (size_t t = 0; t < infos.size(); ++t) { - if (t == pos) continue; - auto& engine_info = infos.at(t); - for (const auto& eng_conn : engine_info.connections) { - if (eng_conn.is_input_edge) continue; - if (eng_conn.inside_node_name == input_node) { - input_node = engine_info.engine_name; - if (eng_conn.inside_port == input_port) { - input_port = eng_conn.port_number; - found_engine = true; - break; - } - } + if (!control_input_names.insert(input_node->name()).second) { + continue; } - if (found_engine) break; - } - VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " - << info.engine_name << ":" << inputs.size(); - // Skip duplicate inputs. - // TODO(aaroey): use std::find instead. GetEngineInfo already remove - // duplicate connections, so here we should never find any duplicate? - bool new_input = true; - for (const auto& inp : inputs) { - if (inp.node == input_node && inp.index == input_port) { - new_input = false; - break; + control_input_nodes.push_back(input_node); + VLOG(1) << "Engine Control Input " << input_node->name() << " -> " + << info.engine_name; + } else { + // Data edges + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (output_shape_protos.size() <= conn.port_number) { + output_shape_protos.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + output_shape_protos.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + } else { + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shape_protos.size() <= conn.port_number) { + input_shape_protos.resize(conn.port_number + 1); + input_shapes.resize(conn.port_number + 1); + } + input_shape_protos.at(conn.port_number) = in_shape; + input_shapes.at(conn.port_number) = conn.outside_shape; + + // Rewrire data input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + } + if (std::find_if( + std::begin(inputs), std::end(inputs), + [input_node, &port](const NodeDefBuilder::NodeOut& inp) { + return inp.node == input_node->name() && inp.index == port; + }) == std::end(inputs)) { + inputs.emplace_back(input_node->name(), port, conn.connection_type); + input_nodes.push_back(CHECK_NOTNULL(input_node)); + VLOG(1) << "Engine Input " << input_node->name() << ":" << port + << " -> " << info.engine_name << ":" << inputs.size() - 1; + } } } - if (new_input) { - inputs.emplace_back(input_node, input_port, conn.connection_type); - } } - - // Build the engine and get its serialized representation. string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic || info.precision_mode == INT8MODE) { @@ -485,21 +574,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // TODO(aaroey): use enum instead, and add a helper method to do the // conversion. string prec_string; - switch (info.precision_mode) { - case FP32MODE: - prec_string = "FP32"; - break; - case FP16MODE: - prec_string = "FP16"; - break; - case INT8MODE: - prec_string = "INT8"; - if (!TRTResourceManager::instance()->getManager("TRTCalibration")) { - LOG(ERROR) << "Failed to construct calibration storage"; - } - break; - default: - return tensorflow::errors::OutOfRange("Unknown precision mode"); + TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); + if (info.precision_mode == INT8MODE && + !TRTResourceManager::instance()->getManager("TRTCalibration")) { + LOG(ERROR) << "Failed to construct calibration storage"; } tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); if (!info.device.empty()) node_builder.Device(info.device); @@ -511,6 +589,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, VLOG(1) << ins; } node_builder.Input(inputs); + for (const string& c : control_input_names) { + node_builder.ControlInput(c); + } + if (info.engine_type == EngineInfo::EngineType::TRTStatic && info.cached_engine_batches.size()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; @@ -539,34 +621,55 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // Up until this point, graph is not modified. If we return !status.ok() from // here, this segment will be skipped + // TODO(aaroey): let it return proper error status for the following logic + // instead of checking fail. tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + (*engine_nodes)[pos] = engine_node; if (!status.ok()) { LOG(ERROR) << "Adding node failed " << status; return status; } + // Add control input and input edges to the engine node. + for (const auto in : control_input_nodes) { + VLOG(1) << "Connecting control edge from " << in->name() << " to " + << engine_node->name(); + graph->AddControlEdge(in, engine_node); + } + VLOG(1) << "input_nodes size = " << input_nodes.size(); + for (int i = 0; i < input_nodes.size(); ++i) { + Node* n = CHECK_NOTNULL(input_nodes[i]); + const auto& in = inputs[i]; + VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index + << " to " << engine_node->name() << ":" << i; + graph->AddEdge(n, in.index, engine_node, i); + } + // Updates the inputs of output edges destination nodes, and point them to the // engine node. for (auto& conn : info.connections) { - if (conn.is_input_edge) continue; - VLOG(1) << " Updating DBG " << engine_node->name() << " out_port " - << conn.port_number << " out_id " << conn.outside_id - << " name=" << conn.outside_node_name; - auto dst_node = graph->FindNodeId(conn.outside_id); - // dst_node can only be removed if it is an input node of another engine. - // In this case, other engines input edge is updated in nodedef to point to - // this engine. Even though edge doesn't exists in the graph, when it is - // deserialized again, correct edges will be constructed. This is a problem - // of graph->AddNode(). - if (!dst_node) continue; + if (conn.is_input_edge) { + continue; + } + tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!output_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, + conn.outside_node_name, &output_node, &port); + } VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number - << " to " << dst_node->name() << ":" << conn.outside_port; - auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node, - conn.outside_port); - CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":" - << conn.port_number << " -> " << dst_node->name() << ":" - << conn.outside_port; + << " to " << output_node->name() << ":" << port; + if (conn.is_control_edge()) { + QCHECK_EQ(Graph::kControlSlot, port); + graph->AddControlEdge(engine_node, output_node); + } else { + auto new_edge = + graph->AddEdge(engine_node, conn.port_number, output_node, port); + QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name() + << ":" << conn.port_number << " -> " + << output_node->name() << ":" << conn.outside_port; + } } - return status; + return Status::OK(); } // Function to construct a funcdef from the segment and add it to the graph. @@ -666,72 +769,36 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( } std::pair GetDeviceAndAllocator( - ConversionParams& params, EngineInfo& engine) { + const ConversionParams& params, const EngineInfo& engine) { int cuda_device_id = -1; - auto check_device_id = [](int tfid) -> int { - tensorflow::TfGpuId tf_gpu_id(tfid); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); - if (s.ok()) { - VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " - << cuda_gpu_id.value(); - return cuda_gpu_id.value(); - } - VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s; - return -1; - }; tensorflow::Allocator* dev_allocator = nullptr; - // we need to us PM here since in python path there is no way to get - // to allocators. - // TODO(sami): when grappler devices become available else path will not be - // necessary - auto pm = tensorflow::GPUProcessState::singleton(); - if (params.cluster) { // get allocator - tensorflow::Device* device = nullptr; - if (params.cluster->GetDeviceSet()) { - device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device); + if (params.cluster) { + std::vector devices; + if (!engine.device.empty() && params.cluster->GetDeviceSet()) { + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) && + parsed_name.has_id) { + params.cluster->GetDeviceSet()->FindMatchingDevices(parsed_name, + &devices); + } } - if (device) { + if (!devices.empty()) { + if (devices.size() > 1) { + string msg = "Found multiple matching devices using name '"; + StrAppend(&msg, engine.device, "': "); + for (auto d : devices) StrAppend(&msg, d->name(), ", "); + StrAppend(&msg, ". Will get the allocator from first one."); + LOG(WARNING) << msg; + } tensorflow::AllocatorAttributes alloc_attr; - dev_allocator = device->GetAllocator(alloc_attr); - VLOG(1) << "Using allocator " << dev_allocator->Name(); + cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id; + dev_allocator = devices[0]->GetAllocator(alloc_attr); + VLOG(1) << "Using allocator " << dev_allocator->Name() + << " and cuda_device_id " << cuda_device_id; } else { LOG(WARNING) << "Cluster is set but device '" << engine.device << "' is not found in the cluster"; } - } else { // cluster not found, possibly a python call - VLOG(1) << "Cluster is not set, probably called from python"; - int found_device = 0; - bool try_gpu_ids = true; - // if device is set, try to find the device. Might be a problem for multi - // host case but TensorRT do not support multi host setups yet. - if (!engine.device.empty()) { - DeviceNameUtils::ParsedName parsed_name; - if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) { - cuda_device_id = parsed_name.has_id ? parsed_name.id : -1; - } - try_gpu_ids = !parsed_name.has_id; - } - if (try_gpu_ids) { - while (found_device < 100) { - cuda_device_id = check_device_id(found_device); - if (cuda_device_id >= 0) break; - found_device++; - } - } - if (found_device == 100) { - LOG(ERROR) << " Can't find a GPU device to work with. Please " - "instantiate a session to initialize devices"; - return std::make_pair(cuda_device_id, dev_allocator); - } - LOG(WARNING) - << "Can't determine the device, constructing an allocator at device " - << found_device; - tensorflow::GPUOptions gpuoptions; - // this will be a noop if device is already initialized - gpuoptions.set_allow_growth(true); - tensorflow::TfGpuId tf_gpu_id(found_device); - dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); } return std::make_pair(cuda_device_id, dev_allocator); } @@ -824,6 +891,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); } VLOG(1) << "Current cuda device is " << old_cuda_device; + std::vector engine_nodes; + engine_nodes.resize(engine_segments.size()); for (int i = 0; i < engine_segments.size(); ++i) { auto& engine = engine_segments.at(i); // Partition the workspace size by the average of node ratio and segment @@ -847,19 +916,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } cudaSetDevice(cuda_device_id); - auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(), - params.max_batch_size); + auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, + &graph, alloc.get(), &engine_nodes); // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. + const string msg = StrCat("Engine ", engine.engine_name, + " creation for segment ", i, ", composed of ", + converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { + LOG(INFO) << msg << " succeeded."; for (auto node_name : converted_segments.at(i).first) { graph.RemoveNode(node_map.at(node_name)); } } else { // Graph is not modified. - LOG(WARNING) << "Engine creation for segment " << i << ", composed of " - << converted_segments.at(i).first.size() - << " nodes failed: " << status << ". Skipping..."; + LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; } } cudaSetDevice(old_cuda_device); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 9d881eda909a8fe9abe20b1696411dfe60b663d1..7aec963aa3d8cfc6e8b9284d692654b716080b41 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -2690,7 +2691,7 @@ tensorflow::Status ConvertGraphDefToEngine( // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { string node_name = node_def.name(); - VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op(); + VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && (node_def.op() == "Placeholder")) { int32 slot_number = -1; @@ -2791,6 +2792,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set& subgraph_node_names, const std::vector& subgraph_node_ids, // In topological order std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope) { @@ -2799,6 +2801,7 @@ tensorflow::Status ConvertSegmentToGraphDef( // nodes in the segment graphdef. for (size_t i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); + if (connection.is_control_edge()) continue; auto outside_node = graph->FindNodeId(connection.outside_id); if (!outside_node) { // This should never happen, unless the original graph is problematic. @@ -2812,13 +2815,13 @@ tensorflow::Status ConvertSegmentToGraphDef( GetInputProperties(graph_properties, graph->FindNodeId(connection.outside_id), connection.outside_port, &partial_shape, &dtype); - + connection.outside_shape = partial_shape; } else { GetOutputProperties(graph_properties, graph->FindNodeId(connection.outside_id), connection.outside_port, &partial_shape, &dtype); + connection.inside_shape = partial_shape; } - connection.outside_shape = partial_shape; connection.connection_type = dtype; // Add dummy input/output nodes to the segment graphdef. @@ -2871,12 +2874,12 @@ tensorflow::Status ConvertSegmentToGraphDef( old_to_new_id_map[node_id] = segment_def->node_size(); auto snode = segment_def->add_node(); snode->CopyFrom(node->def()); - VLOG(1) << "Copying " << snode->name() << " to subgraph"; + VLOG(2) << "Copying " << snode->name() << " to subgraph"; } // Update the inputs of the new input nodes to point to placeholder nodes. for (int i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); - if (!connection.is_input_edge) continue; + if (connection.is_control_edge() || !connection.is_input_edge) continue; auto snode = segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); const string placeholder_name = @@ -2886,6 +2889,39 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + // Remove control inputs that are not inside the segment. + for (int i = 0; i < segment_def->node_size(); ++i) { + auto snode = segment_def->mutable_node(i); + const int input_size = snode->input_size(); + int input_idx = 0; + int actual_input_idx = 0; + while (input_idx < input_size) { + TensorId input = ParseTensorName(snode->input(input_idx)); + if (!subgraph_node_names.count( + string(input.first.data(), input.first.size())) && + !str_util::StartsWith(input.first, kInputPHName)) { + if (input.second == Graph::kControlSlot) { + VLOG(1) << "... removing control inputs " << input.first + << " from subgraph."; + ++input_idx; + continue; + } else { + return tensorflow::errors::InvalidArgument( + "Found non control input outside the segment that is not an " + "engine connection to ", + snode->name(), ": ", input.first); + } + } + if (actual_input_idx != input_idx) { + snode->set_input(actual_input_idx, snode->input(input_idx)); + } + ++input_idx; + ++actual_input_idx; + } + for (int remove = input_size - actual_input_idx; remove > 0; --remove) { + snode->mutable_input()->RemoveLast(); + } + } *common_scope = local_scope; VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); @@ -2900,7 +2936,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { nvinfer1::DataType trt_dtype; Status status = ValidateInputProperties(shape, dtype, &trt_dtype); if (!status.ok()) { - VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() << ": " << status; return false; } @@ -2908,7 +2944,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { #if NV_TENSORRT_MAJOR == 3 // TRT 3.x only support 4 dimensional input tensor. if (shape.dims() != 4 && in_edge->src()->type_string() != "Const") { - VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() << " which has an input at port " << in_edge->dst_input() << " with #dim!=4 and is not a const: " << shape; return false; @@ -2920,7 +2956,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { if (out_edge->IsControlEdge()) return true; if (out_edge->src()->type_string() == "Const") { - VLOG(2) << "--> Need to remove output node " << out_edge->src()->name() + VLOG(1) << "--> Need to remove output node " << out_edge->src()->name() << " which is a Const."; return false; } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 6ae60ec352587feb8b26d6fcc69c907a5f145760..a60253740fe0b27dcd9c20618d6d05aa7001a1a1 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -36,16 +36,12 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -static const char* kInputPHName = "InputPH_"; -static const char* kOutputPHName = "OutputPH_"; +static const char* kInputPHName = "TensorRTInputPH_"; +static const char* kOutputPHName = "TensorRTOutputPH_"; namespace convert { -// TODO(aaroey): use an enum instead. -const int FP32MODE = 0; -const int FP16MODE = 1; -const int INT8MODE = 2; - struct EngineConnection { + // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, const string& inside, int in_id, int in_port, bool input_edge, int port) @@ -58,21 +54,35 @@ struct EngineConnection { is_input_edge(input_edge), port_number(port) {} + // Constructs a control edge. + EngineConnection(const string& outside, int out_id, const string& inside, + int in_id, bool input_edge) + : outside_node_name(outside), + outside_id(out_id), + outside_port(Graph::kControlSlot), + inside_node_name(inside), + inside_id(in_id), + inside_port(Graph::kControlSlot), + is_input_edge(input_edge), + port_number(Graph::kControlSlot) {} + + bool is_control_edge() const { return port_number == Graph::kControlSlot; } + const string outside_node_name; const int outside_id; const int outside_port; - tensorflow::PartialTensorShape outside_shape; + tensorflow::PartialTensorShape outside_shape; // Only set for input edge. const string inside_node_name; const int inside_id; const int inside_port; - tensorflow::PartialTensorShape inside_shape; + tensorflow::PartialTensorShape inside_shape; // Only set for output edge. tensorflow::DataType connection_type; - bool is_input_edge; + const bool is_input_edge; - // The port number of the TRT node connecting to this edge. - int port_number; + // The port number of the TRT node connected with this edge. + const int port_number; }; struct EngineInfo { @@ -85,7 +95,9 @@ struct EngineInfo { string device; tensorflow::GraphDef segment_graph_def; - // The segment nodes that are on one side of the edges are topological sorted. + // Non-control input connections inside this vector are sorted in a way such + // that, the segment nodes connecting to them are topological sorted. + // In addition, for non-control connections, there must be no duplicates. std::vector connections; enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; @@ -101,6 +113,7 @@ struct EngineInfo { // (OutputPH_*). This function needs to be called before TensorRT nodes // inserted in order to correctly get sizes from the original graph. // +// - subgraph_node_names: the node names of the subgraph. // - subgraph_node_ids: the node ids of the subgraph, must be sorted in // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be @@ -110,6 +123,7 @@ struct EngineInfo { tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set& subgraph_node_names, const std::vector& subgraph_node_ids, std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope); diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 044c736c03e0dcad0d27d6b9ad9d244816596536..f33f2cc4d68f5ac10eafeb744f8162bfca0abfab 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stacktrace.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -189,9 +190,6 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::Cluster* cluster, const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) { VLOG(1) << "Called TRTOptimization Pass " << name_; - if (VLOG_IS_ON(1)) { - PrintDebugInfo(cluster, item); - } // This is a hack to workaround optimizer issue. MetaOptimizer calls // optimization passes on function objects as well, we should not modify // generated funcdefs! This is fragile but we don't have any other option @@ -203,6 +201,10 @@ tensorflow::Status TRTOptimizationPass::Optimize( *optimized_graph = item.graph; return tensorflow::Status::OK(); } + if (VLOG_IS_ON(1)) { + VLOG(2) << CurrentStackTrace(); + PrintDebugInfo(cluster, item); + } int max_dim = -1; if (item.feed.size()) { for (const auto& f : item.feed) { diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc index 24591cf84b244a548441b876228c5819d82ddefb..e7a1febb8c076891596741fe30721e7acca15a73 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.cc +++ b/tensorflow/contrib/tensorrt/convert/utils.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + namespace tensorflow { namespace tensorrt { @@ -24,12 +27,43 @@ bool IsGoogleTensorRTEnabled() { // safely write code that uses tensorrt conditionally. E.g. if it does not // check for for tensorrt, and user mistakenly uses tensorrt, they will just // crash and burn. -#ifdef GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT return true; #else return false; #endif } +Status GetPrecisionModeName(const int precision_mode, string* name) { + switch (precision_mode) { + case FP32MODE: + *name = "FP32"; + break; + case FP16MODE: + *name = "FP16"; + break; + case INT8MODE: + *name = "INT8"; + break; + default: + return tensorflow::errors::OutOfRange("Unknown precision mode"); + } + return Status::OK(); +} + +Status GetPrecisionMode(const string& name, int* precision_mode) { + if (name == "FP32") { + *precision_mode = FP32MODE; + } else if (name == "FP16") { + *precision_mode = FP16MODE; + } else if (name == "INT8") { + *precision_mode = INT8MODE; + } else { + return tensorflow::errors::InvalidArgument("Invalid precision mode name: ", + name); + } + return Status::OK(); +} + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h index 8b5f4d614a9c1f849f0aec9df42100bb4126b439..0592f31462af2b20f3a13fe5119e89c2ba42dd8a 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.h +++ b/tensorflow/contrib/tensorrt/convert/utils.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "tensorflow/core/lib/core/status.h" + namespace tensorflow { namespace tensorrt { @@ -33,6 +35,15 @@ using TrtUniquePtrType = std::unique_ptr>; bool IsGoogleTensorRTEnabled(); +// TODO(aaroey): use an enum instead. +const int FP32MODE = 0; +const int FP16MODE = 1; +const int INT8MODE = 2; + +Status GetPrecisionModeName(const int precision_mode, string* name); + +Status GetPrecisionMode(const string& name, int* precision_mode); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 646d62483f577436c1396c690da492eaea10966e..2b42d81f475189f74a934c3aeed7d7fc34d4eb53 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -45,11 +46,11 @@ using ::tensorflow::strings::StrCat; // Helps simultaneous execution of native and TRT engines. class AsyncHelper : public tensorflow::core::RefCounted { public: - AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; } + AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; } ~AsyncHelper() override { done_(); } private: - tensorflow::AsyncOpKernel::DoneCallback done_; + AsyncOpKernel::DoneCallback done_; }; #define TYPECASE(dt, X, Y) \ @@ -122,15 +123,9 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("calibration_data", &calibration_data)); OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); - if (precision_string == "FP32") { - precision_mode_ = convert::FP32MODE; - } else if (precision_string == "FP16") { - precision_mode_ = convert::FP16MODE; - } else if (precision_string == "INT8") { - precision_mode_ = convert::INT8MODE; - } + OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); calibration_mode_ = - (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0); + (precision_mode_ == INT8MODE && calibration_data.size() == 0); if (calibration_data.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); @@ -152,7 +147,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) } } -void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, +void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper) { if (!calibration_mode_) { VLOG(1) << "Executing native engine"; @@ -179,7 +174,7 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment " << name(); lib->Run(opts, native_func_, inputs, outputs, - [ctx, outputs, helper](const tensorflow::Status& s) { + [this, ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); VLOG(1) << "Native Segment completed"; if (!s.ok()) { @@ -189,11 +184,13 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } + test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"), + "done"); delete outputs; }); } -void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, +void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper) { helper->Ref(); tensorflow::core::ScopedUnref sc(helper); @@ -234,11 +231,12 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, ->implementation() ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); + test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done"); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); } -int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) { +int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { int num_batch = ctx->input(0).shape().dim_size(0); int smallest_engine = 0; for (const auto i : cached_engine_batches_) { @@ -254,21 +252,20 @@ int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) { cached_engine_batches_.push_back(num_batch); VLOG(1) << "Running with batch size " << num_batch; } else { - string s("Engine buffer is full. buffer limit= "); - StrAppend(&s, max_cached_engines_, ", current entries= "); - for (auto i : cached_engine_batches_) StrAppend(&s, i, ", "); - StrAppend(&s, "Requested batch= ", num_batch); - LOG(ERROR) << s; - ctx->SetStatus(tensorflow::errors::ResourceExhausted( - "Requested batch size is not available and engine cache is full")); + string msg = + StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, + ", current entries="); + for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); + StrAppend(&msg, " requested batch=", num_batch); + LOG(WARNING) << msg; return -1; } } return smallest_engine; } -void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, - tensorflow::AsyncOpKernel::DoneCallback done) { +void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, + AsyncOpKernel::DoneCallback done) { auto helper = new AsyncHelper(done); tensorflow::core::ScopedUnref sc(helper); if (calibration_mode_) { @@ -276,32 +273,54 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, return; } const int smallest_engine = GetEngineBatch(ctx); - if (smallest_engine < 0) return; // GetEngineBatch already set the status. + if (smallest_engine < 0) { + LOG(WARNING) << "Failed to get engine batch, running native segment for " + << name(); + ExecuteNativeSegment(ctx, helper); + return; + } const int num_batch = ctx->input(0).shape().dim_size(0); auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); auto& trt_engine_ptr = engine_ctx_pair.first; if (!trt_engine_ptr) { LOG(WARNING) << "Engine retrieval for batch size " << num_batch - << " failed Running native segment"; + << " failed. Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } + const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), + engine_ctx_pair.second.get()); + if (retry) { + LOG(WARNING) << "Failed to execute engine, " + << "retrying with native segment for " << name(); + ExecuteNativeSegment(ctx, helper); + return; + } +} +bool TRTEngineOp::ExecuteTrtEngine( + OpKernelContext* ctx, const int num_batch, + nvinfer1::ICudaEngine* trt_engine_ptr, + nvinfer1::IExecutionContext* trt_execution_context_ptr) { + const bool kRetry = true; const int num_binding = ctx->num_inputs() + ctx->num_outputs(); std::vector buffers(num_binding); for (int i = 0; i < ctx->num_inputs(); i++) { - const string inp_name = StrCat(kInputPHName, i); + const string input_name = StrCat(kInputPHName, i); const size_t binding_index = - trt_engine_ptr->getBindingIndex(inp_name.c_str()); + trt_engine_ptr->getBindingIndex(input_name.c_str()); + if (binding_index == -1) { + LOG(ERROR) << "Input node not found, at " << input_name; + return kRetry; + } const Tensor& input_tensor = ctx->input(i); const TensorShape& input_shape = input_tensor.shape(); if (num_batch != input_shape.dim_size(0)) { - LOG(ERROR) << "input data inconsistent batch size"; - ctx->SetStatus(tensorflow::errors::FailedPrecondition( - "Different batch sizes between input tensors")); - return; + LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch + << " vs " << input_shape.dim_size(0); + return kRetry; } auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { @@ -310,14 +329,10 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, break; case nvinfer1::DataType::kHALF: LOG(ERROR) << "FP16 inputs are not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "FP16 inputs are not supported!")); - return; + return kRetry; case nvinfer1::DataType::kINT8: LOG(ERROR) << "INT8 inputs are not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "INT8 inputs are not supported!")); - return; + return kRetry; #if NV_TENSORRT_MAJOR > 3 case nvinfer1::DataType::kINT32: buffers[binding_index] = (void*)(input_tensor.flat().data()); @@ -325,9 +340,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, #endif default: LOG(ERROR) << "Unknown TRT data type: " << int(dtype); - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Unknown output TRT data type! ", static_cast(dtype))); - return; + return kRetry; } } @@ -344,20 +357,23 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; - OP_REQUIRES_OK( - ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), - &output_shape)); + auto status = TensorShapeUtils::MakeShape( + trt_shape.data(), trt_shape.size(), &output_shape); + if (!status.ok()) { + LOG(ERROR) << "Failed to get output shape: " << status; + return kRetry; + } } else { - LOG(ERROR) << "output node not found, at " << output_name; - ctx->SetStatus(tensorflow::errors::Internal("output ", output_name, - " couldn't be found!")); - return; + LOG(ERROR) << "Output node not found, at " << output_name; + return kRetry; } auto status = ctx->allocate_output(i, output_shape, &output_tensor); if (!status.ok()) { LOG(ERROR) << "Allocating output failed with " << status; ctx->SetStatus(status); - return; + // Do not retry since we cannot allocate the same output twice. + // TODO(aaroey): ideally we should retry, fix this. + return !kRetry; } auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { @@ -366,15 +382,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, reinterpret_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(ERROR) << "half size is not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Half outputs are not supported!")); - return; + LOG(WARNING) << "half size is not supported yet!"; + return kRetry; case nvinfer1::DataType::kINT8: - LOG(ERROR) << "int8 is not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "INT8 outputs are not supported!")); - return; + LOG(WARNING) << "int8 is not supported yet!"; + return kRetry; #if NV_TENSORRT_MAJOR > 3 case nvinfer1::DataType::kINT32: buffers[binding_index] = @@ -382,13 +394,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, break; #endif default: - LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Unsupported output data type! ", static_cast(dtype))); - return; + LOG(WARNING) << "Unknown TRT data type: " << static_cast(dtype); + return kRetry; } } - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files const cudaStream_t* stream = CHECK_NOTNULL( reinterpret_cast(ctx->op_device_context() ->stream() @@ -396,15 +406,15 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, ->GpuStreamMemberHack())); // TODO(jie): trt enqueue does not return error - auto& trt_execution_context_ptr = engine_ctx_pair.second; auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, nullptr); if (!ret) { - LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name(); - ctx->SetStatus(tensorflow::errors::Internal( - "Failed to enqueue batch for TRT engine: ", name())); + LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); + return kRetry; } - // sync should be done by TF. + test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done"); + // Synchronization will be done by TF. + return !kRetry; } TRTEngineOp::~TRTEngineOp() { @@ -424,8 +434,6 @@ nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { if (!alloc) { LOG(ERROR) << "Can't find device allocator for gpu device " << device->name(); - ctx->SetStatus(tensorflow::errors::Internal( - "Can't get device allocator for device ", device->name())); return nullptr; } allocator_.reset(new TRTDeviceAllocator(alloc)); @@ -452,7 +460,6 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, #if NV_TENSORRT_MAJOR > 3 auto allocator = GetAllocator(ctx); if (allocator == nullptr) { - // GetAllocator already set the Status. return null_pair; } infer->setGpuAllocator(allocator); @@ -469,7 +476,9 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, raw_static_engine->createExecutionContext())}; // Runtime is safe to delete after engine creation serialized_segment_.clear(); - if (max_batch_size < batch_size) return null_pair; + if (max_batch_size < batch_size) { + return null_pair; + } return engine_map_.at(max_batch_size); } // static_engine_ @@ -481,7 +490,6 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, #if NV_TENSORRT_MAJOR > 3 allocator = GetAllocator(ctx); if (allocator == nullptr) { - // GetAllocator already set the Status. return null_pair; } #endif @@ -505,9 +513,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, // retry in the future. engine_map_[batch_size] = {nullptr, nullptr}; } - LOG(ERROR) << "Engine creation for batch size " << batch_size - << " failed " << status; - ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!")); + LOG(WARNING) << "Engine creation for batch size " << batch_size + << " failed " << status; return null_pair; } VLOG(1) << "Conversion is done"; @@ -519,7 +526,7 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, } tensorflow::Status TRTEngineOp::AllocateCalibrationResources( - tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) { + OpKernelContext* ctx, TRTCalibrationResource** cr) { auto cres = new TRTCalibrationResource(); *cr = cres; // Get the allocator. @@ -583,7 +590,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( // TODO(aaroey): maybe setting the max batch size using the python // calibration wrapper class. auto s = convert::ConvertGraphDefToEngine( - *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(), + *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(), &cres->engine_, /*convert_successfully=*/nullptr); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 9265250605dc7bbf2d85c6853bc6e8bf379ce72f..8fe06758914261035c90a6fda3f114a63a8ac93a 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -35,7 +35,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class TRTInt8Calibrator; +struct TRTInt8Calibrator; class TRTCalibrationResource; class AsyncHelper; // TODO(Sami): Remove this file? @@ -60,6 +60,12 @@ class TRTEngineOp : public AsyncOpKernel { // Execute replaced native segment as function Op. void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + // Execute the tensorrt engine. Returns whether we need to retry by running + // the native segment. + bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch, + nvinfer1::ICudaEngine* trt_engine_ptr, + nvinfer1::IExecutionContext* trt_execution_context_ptr); + // Allocate necessary resources for calibration Status AllocateCalibrationResources(OpKernelContext* ctx, TRTCalibrationResource** cr); diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index fe4fa166a10d914d028938925266683e62861421..7cdfe2b1a612be2eec473d806d0eb44b611ca68a 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -20,7 +20,11 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph +from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value +from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 2b67931661397cee0de9faa66b58a608c69ecdc5..4116f2fe30aa5c0c9ea139100291abe3b13da94b 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -20,26 +20,26 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert +from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values +from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled -from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl as _impl -from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.grappler import tf_optimizer from tensorflow.python.platform import tf_logging -from tensorflow.python.util import compat - +from tensorflow.python.training import saver # pylint: enable=unused-import,line-too-long -# TODO(skama): get outputs from session when implemented as c++ -# optimization pass def create_inference_graph(input_graph_def, outputs, max_batch_size=1, @@ -48,7 +48,7 @@ def create_inference_graph(input_graph_def, minimum_segment_size=3, is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batches=[]): + cached_engine_batches=None): """Python wrapper for the TRT transformation. Args: @@ -87,8 +87,7 @@ def create_inference_graph(input_graph_def, (".".join([str(x) for x in compiled_version]), ".".join([str(x) for x in loaded_version])) + ". Please make sure that correct version of TensorRT " + - "is available in the system and added to ldconfig or LD_LIBRARY_PATH" - ) + "is available in the system and added to ldconfig or LD_LIBRARY_PATH") raise RuntimeError("Incompatible TensorRT library version") for i in zip(loaded_version, compiled_version): if i[0] != i[1]: @@ -121,41 +120,42 @@ def create_inference_graph(input_graph_def, to_bytes = py3bytes to_string = py3string - out_names = [] - for i in outputs: - if isinstance(i, ops.Tensor): - out_names.append(to_bytes(i.name)) - else: - out_names.append(to_bytes(i)) - - input_graph_def_str = input_graph_def.SerializeToString() - - # TODO(sami): Fix this when we can return status from C++ library - # There is a problem with the TF internal library setup that doesn't - # allow us to return a status object from C++. Thus we return a - # pair or strings where first one is encoded status and the second - # one is the transformed graphs protobuf string. - out = trt_convert(input_graph_def_str, out_names, max_batch_size, - max_workspace_size_bytes, mode, minimum_segment_size, - is_dynamic_op, maximum_cached_engines, - cached_engine_batches) - status = to_string(out[0]) - output_graph_def_string = out[1] - del input_graph_def_str # Save some memory - if len(status) < 2: - raise _impl.UnknownError(None, None, status) - if status[:2] != "OK": - msg = status.split(";") - if len(msg) == 1: - raise RuntimeError("Status message is malformed {}".format(status)) - # pylint: disable=protected-access - raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), - int(msg[0])) - # pylint: enable=protected-access - output_graph_def = graph_pb2.GraphDef() - output_graph_def.ParseFromString(output_graph_def_string) - del output_graph_def_string # Save some memory - return output_graph_def + # Create MetaGraphDef + graph = ops.Graph() + with graph.as_default(): + importer.import_graph_def(input_graph_def, name="") + meta_graph = saver.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph) + if outputs: + output_collection = meta_graph_pb2.CollectionDef() + output_list = output_collection.node_list.value + for i in outputs: + if isinstance(i, ops.Tensor): + output_list.append(to_bytes(i.name)) + else: + output_list.append(to_bytes(i)) + meta_graph.collection_def["train_op"].CopyFrom(output_collection) + + # Create RewriterConfig. + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + optimizer = rewriter_cfg.custom_optimizers.add() + optimizer.name = "TensorRTOptimizer" + optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size + optimizer.parameter_map["max_batch_size"].i = max_batch_size + optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op + optimizer.parameter_map[ + "max_workspace_size_bytes"].i = max_workspace_size_bytes + optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode) + optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines + if cached_engine_batches: + if not isinstance(cached_engine_batches, list): + raise TypeError("cached_engine_batches should be a list.") + optimizer.parameter_map["cached_engine_batches"].list.i.extend( + cached_engine_batches) + + return tf_optimizer.OptimizeGraph( + rewriter_cfg, meta_graph, graph_id=b"tf_graph") def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 008fffc95430b1c423788a4e958e06e700cac233..b43f1b190f5f8cfe98959dd9f2838e4d45759e5c 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -414,10 +414,10 @@ tensorflow::Status SegmentGraph( } for (const SimpleNode* node : order) { // All output nodes of 'node' have been visited... - VLOG(2) << "Trying node " << node->name() << " id=" << node->id(); + VLOG(3) << "Trying node " << node->name() << " id=" << node->id(); // 'node' must be a TRT candidate... if (node_segments[node->id()].Value() == nullptr) { - VLOG(2) << "... not a TRT candidate"; + VLOG(3) << "... not a TRT candidate"; continue; } // Contract output edges to combine 'node' with output @@ -426,22 +426,22 @@ tensorflow::Status SegmentGraph( while (true) { std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { - VLOG(2) << "... out node " << out_edge->dst()->name() << " ( " + VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; if (out_edge->IsControlEdge()) { - VLOG(2) << "... ... Control Edge, Skipping"; + VLOG(3) << "... ... Control Edge, Skipping"; continue; } // Out node must be TRT candidate... if (node_segments[out_edge->dst()->id()].Value() == nullptr) { - VLOG(2) << "... ... not a TRT candidate"; + VLOG(3) << "... ... not a TRT candidate"; continue; } if (CanContractEdge(out_edge, graph)) { - VLOG(2) << "... ... can contract"; + VLOG(3) << "... ... can contract"; contract_edges.insert(out_edge); } else { - VLOG(2) << "... ... cannot contract, would form cycle"; + VLOG(3) << "... ... cannot contract, would form cycle"; } } if (contract_edges.empty()) { @@ -454,7 +454,7 @@ tensorflow::Status SegmentGraph( const SimpleNode* src = contract_edge->src(); const SimpleNode* dst = contract_edge->dst(); - VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " (" + VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " (" << src->id() << " <- " << dst->id(); node_segments[src->id()].Merge(&node_segments[dst->id()]); @@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::unordered_map> sg_map; + std::map> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -558,27 +558,36 @@ tensorflow::Status SegmentGraph( // then after doing this operation the resulting subgraph will keep the // same properties 1 and 2. // - // For simplicity we use heuristics: for input nodes remove all its - // input, for output nodes remove all its output. In this way, for common - // cases the number of removed nodes should be minimum. + // For simplicity we use heuristics: for input and const output nodes + // remove all their inputs, and for non-const output nodes remove all + // their outputs. In this way, for common cases the number of removed + // nodes should be minimum. auto remove_nodes = [&segment_nodes]( bool is_input_nodes, std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. std::set visited; + std::set logged(que->begin(), que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); if (!visited.insert(node).second) continue; segment_nodes.erase(node); - for (auto in : - is_input_nodes ? node->in_nodes() : node->out_nodes()) { + for (auto in : (is_input_nodes || node->type_string() == "Const") + ? node->in_nodes() + : node->out_nodes()) { if (segment_nodes.count(in)) { que->push_back(in); - VLOG(2) << "Need to remove node " << in->name() - << " because one of its " - << (is_input_nodes ? "output" : "input") - << " nodes in the graph was removed: " << node->name(); + if (VLOG_IS_ON(2)) { + if (!logged.count(in)) { + VLOG(2) << "----> Need to remove node " << in->name() + << " because one of its " + << (is_input_nodes ? "output" : "input") + << " nodes in the graph was removed: " + << node->name(); + logged.insert(in); + } + } } } } @@ -594,7 +603,7 @@ tensorflow::Status SegmentGraph( for (const auto& itr : sg_map) { const std::set& segment_nodes = itr.second; if (VLOG_IS_ON(1)) { - string s; + string s = "parent=" + itr.first + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 432e7b1c047cb3b22d47f7432b6aad639a3a3b2d..5937fa8259a39339e92b150862d195ee1f23f70a 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -206,7 +206,7 @@ TEST_F(SegmentTest, Multiple) { // Make add5 not a TRT candidate, and we expect two segments. auto without_add5 = all_adds - "add5"; RunTest(&g, without_add5, without_add5, without_add5, - {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}}); + {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}}); // Make add8 not a candidate and add6 not an input candidate, then all direct // and indirect inputs of add6 will be removed from the segment. @@ -252,7 +252,7 @@ TEST_F(SegmentTest, BigIfElse) { const std::set all_adds = {"add0", "add1", "add2", "add3", "add4", "add5", "add6", "add7"}; RunTest(&g, all_adds - "add2", all_adds, all_adds, - {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}}); + {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); } } // namespace test diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc index 3712a9a6fe349d949ef2666652b9d750538d5535..769982c6456f76663e50fe3ec59651127e3720ac 100644 --- a/tensorflow/contrib/tensorrt/tensorrt_test.cc +++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" #if GOOGLE_CUDA @@ -130,6 +132,13 @@ void Execute(nvinfer1::IExecutionContext* context, const float* input, } TEST(TensorrtTest, BasicFunctions) { + // Handle the case where the test is run on machine with no gpu available. + if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) { + LOG(WARNING) << "No gpu device available, probably not being run on a gpu " + "machine. Skipping..."; + return; + } + // Create the network model. nvinfer1::IHostMemory* model = CreateNetwork(); // Use the model to create an engine and then an execution context. diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index edd30ad7a95dd3c7f74634699660caad30c0b645..8ea5a6373525a8045d13f70aa9e12d66d4c08f0a 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -20,17 +20,19 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test -class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): +class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing single segment.""" @@ -65,13 +67,17 @@ class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", + # "relu", "identity", "max_pool"] + expected_engines=["my_trt_op_0"], expected_output_dims=(100, 6, 6, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) -class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): +class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing multiple segment.""" @@ -95,32 +101,246 @@ class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase): padding="SAME", name="conv") c1 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype) - p = conv * c1 + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + p = math_ops.mul(conv, c1, name="mul") c2 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype) - q = conv / c2 + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + q = math_ops.div(conv, c2, name="div") - edge = self.trt_incompatible_op(q) - edge /= edge - r = edge + edge + edge = self.trt_incompatible_op(q, name="incompatible") + edge = math_ops.div(edge, edge, name="div1") + r = math_ops.add(edge, edge, name="add") - p -= edge - q *= edge - s = p + q - s -= r + p = math_ops.sub(p, edge, name="sub") + q = math_ops.mul(q, edge, name="mul1") + s = math_ops.add(p, q, name="add1") + s = math_ops.sub(s, r, name="sub1") array_ops.squeeze(s, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", + # "add", "sub1"]; + # - my_trt_op_1 should have ["weights","conv", "div"] + expected_engines=["my_trt_op_0", "my_trt_op_1"], expected_output_dims=(100, 12, 12, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) -# TODO(aaroey): add a large complex graph to test. +class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestA, self).setUp() + # Let it fail to build the second engine. + trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + for i in range(2): + c = constant_op.constant(1.0, name="c%d" % i) + n = math_ops.add(n, c, name="add%d" % i) + n = math_ops.mul(n, n, name="mul%d" % i) + edge = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([edge]): + c = constant_op.constant(1.0, name="c2") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul2") + c = constant_op.constant(1.0, name="c3") + n = math_ops.add(n, c, name="add3") + n = math_ops.mul(n, n, name="mul3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + # Only the first engine is built. + "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class PartiallyConvertedTestB(PartiallyConvertedTestA): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestB, self).setUp() + # Let it fail to build the first engine. + trt_convert.clear_test_values("") + trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + return super(PartiallyConvertedTestB, self).GetParams()._replace( + expected_engines={ + # Only the second engine is built. + "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + }) + + +class ConstInputTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + # Adds control dependency from the constant op to a trt incompatible op, + # and adds control dependency from the trt incompatible op to all other + # ops, to make sure the constant op cannot be contracted with any trt + # segment that depends on it. + with g.control_dependencies([c]): + d = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add", "add1", "mul"], + "my_trt_op_1": ["add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]}, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add2", "add3", "mul1"], + # Why segment ["add", "add1", "mul"] was assigned segment id 1 + # instead of 0: the parent node of this segment is actually const + # node 'c', but it's removed later since it's const output of the + # segment which is not allowed. + "my_trt_op_1": ["add", "add1", "mul"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + c1 = constant_op.constant(1.0, name="c1") + c2 = constant_op.constant(1.0, name="c2") + d1 = constant_op.constant(1.0, name="d1") + d2 = self.trt_incompatible_op(inp, name="d2") + with g.control_dependencies([d1, d2]): + add = math_ops.add(inp, c1, name="add") + with g.control_dependencies([d1, d2]): + mul = math_ops.mul(add, add, name="mul") + with g.control_dependencies([d1, d2]): + add1 = math_ops.add(mul, mul, name="add1") + edge = self.trt_incompatible_op(add1, name="incompatible") + with g.control_dependencies([d1, d2, add, mul]): + add2 = math_ops.add(edge, c2, name="add2") + with g.control_dependencies([d1, d2, add1, mul]): + mul1 = math_ops.mul(add2, add2, name="mul1") + with g.control_dependencies([d1, d2, add, add1]): + add3 = math_ops.add(mul1, mul1, name="add3") + array_ops.squeeze(add3, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["c1", "add", "add1", "mul"], + "my_trt_op_1": ["c2", "add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py index 730b6843fb9885b8ba0db2ad199b95d9d3219774..2e1107e30383926f6428c6551682caf66cd97498 100644 --- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -66,7 +66,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, w1_name, w2_name], input_dims=[input_dims, w1_dims, w2_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(12, 5, 8, 7), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 0c03a10b640c8b243318bb4327d2ac5aac803be7..8be32f59b48e64412466370950298feafc03b35c 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -102,7 +102,10 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=7, + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4", "my_trt_op_5", "my_trt_op_6" + ], expected_output_dims=(48, 89), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py index dd673463a5930df4d0e4c1c7410b3f5eb88d664c..9316b14da07d5f7e47953504680e14d5d20c17a4 100644 --- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -109,7 +109,24 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=16, + expected_engines=[ + "my_trt_op_0", + "my_trt_op_1", + "my_trt_op_2", + "my_trt_op_3", + "my_trt_op_4", + "my_trt_op_5", + "my_trt_op_6", + "my_trt_op_7", + "my_trt_op_8", + "my_trt_op_9", + "my_trt_op_10", + "my_trt_op_11", + "my_trt_op_12", + "my_trt_op_13", + "my_trt_op_14", + "my_trt_op_15", + ], expected_output_dims=(5, 23040), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py index 8c51c45b0a2c6f370415b9c8ac99a63dd37be900..1874b9dd45390407d3d36798cae620848df50c8d 100644 --- a/tensorflow/contrib/tensorrt/test/concatenation_test.py +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -73,7 +73,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(2, 126), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py index 97b29bf05ddc3a0396472d0500ff53ceca7c5d4b..8c59000b70e04cedc84308249865cfcb23ce80a3 100644 --- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -58,7 +58,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=['my_trt_op_0'], expected_output_dims=(5, 12, 12, 1), allclose_atol=1.e-02, allclose_rtol=1.e-02) diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66eb6be757d3f4dcc390435486f7ed4f6517f875 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BatchMatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 15, 15, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + e1 = constant_op.constant( + np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype) + e2 = constant_op.constant( + np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=e1, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + out = nn.conv2d( + input=conv, + filter=e2, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv_2") + array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(2, 15, 15, 10), + allclose_atol=1.e-02, + allclose_rtol=1.e-02) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py index 734ccf6345777d543138daba2b720c9dc03f3295..fd55b8cd99171fe34424e48a417eb8981b051c17 100644 --- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -77,7 +77,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + expected_engines=["my_trt_op_0", "my_trt_op_1"], expected_output_dims=(2, 4, 5, 4), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index 50265c0845005748d75bf8afc49df11a528c9169..51c905a50b29c017719d66f9049e9b1bc3a9ec97 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -51,15 +51,18 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): name="conv") b = constant_op.constant( np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) - t = conv * b - e = gen_math_ops.tan(conv) - t = t - e + t = math_ops.mul(conv, b, name="mul") + e = self.trt_incompatible_op(conv, name="incompatible") + t = math_ops.sub(t, e, name="sub") array_ops.squeeze(t, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + expected_engines={ + "my_trt_op_0": ["bias", "mul", "sub"], + "my_trt_op_1": ["weights", "conv"] + }, expected_output_dims=(2, 4, 5, 4), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 60b8eb6e8115011c5664212de1eba27259f1b05b..6f85ada4649563d099c6054e8e17da27954071f7 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -20,6 +20,7 @@ from __future__ import print_function from collections import namedtuple import itertools +import os import warnings import numpy as np import six @@ -30,6 +31,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -37,10 +39,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ - "gdef", "input_names", "input_dims", "num_expected_engines", + "gdef", "input_names", "input_dims", "expected_engines", "expected_output_dims", "allclose_atol", "allclose_rtol" ]) +RunParams = namedtuple( + "RunParams", + ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) + PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -48,6 +54,12 @@ def _IsQuantizationMode(mode): return mode == "INT8" +class GraphState(object): + ORIGINAL = 0 + CALIBRATE = 1 + INFERENCE = 2 + + class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" @@ -63,50 +75,96 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def precision_modes(self): return ["FP32", "FP16", "INT8"] + # str is bytes in py2, but unicode in py3. + def _ToUnicode(self, s): + if six.PY2: + if isinstance(s, unicode): + return s + return s.decode("utf-8") + else: + if isinstance(s, str): + return s + return s.decode("utf-8") + def _ToBytes(self, s): if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") return s else: - return s.encode("utf-8") + if isinstance(s, str): + return s.encode("utf-8") + return s def _ToString(self, s): if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") return s else: + if isinstance(s, str): + return s return s.decode("utf-8") + @classmethod + def setUpClass(cls): + """Setup method for the module.""" + super(TfTrtIntegrationTestBase, cls).setUpClass() + trt_convert.enable_test_value() + def setUp(self): """Setup method.""" super(TfTrtIntegrationTestBase, self).setUp() warnings.simplefilter("always") + trt_convert.clear_test_values("") def GetParams(self): """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" raise NotImplementedError() - def _GetConfigProto(self, - params, - use_optimizer, - precision_mode=None, - is_dynamic_op=None): + def _PrepareRun(self, params, graph_state): + """Set up necessary testing environment before calling sess.run().""" + # Clear test values added by TRTEngineOp. + trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + + def _VerifyRun(self, params, graph_state): + """Verify the state after sess.run().""" + for engine_name in params.expected_engines: + if graph_state == GraphState.ORIGINAL: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.CALIBRATE: + self._ExpectCalibration(engine_name, "done") + self._ExpectNativeSegment(engine_name, "done") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.INFERENCE: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "done") + + def _GetConfigProto(self, params, run_params, graph_state): """Get config proto based on specific settings.""" - if use_optimizer: + if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: rewriter_cfg = rewriter_config_pb2.RewriterConfig() rewriter_cfg.optimizers.extend(["constfold", "layout"]) custom_op = rewriter_cfg.custom_optimizers.add() custom_op.name = "TensorRTOptimizer" - custom_op.parameter_map["minimum_segment_size"].i = 3 + custom_op.parameter_map["minimum_segment_size"].i = 2 custom_op.parameter_map["max_batch_size"].i = max( [dims[0] for dims in params.input_dims]) - custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op + custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 custom_op.parameter_map["precision_mode"].s = self._ToBytes( - precision_mode) + run_params.precision_mode) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() gpu_options = config_pb2.GPUOptions() + gpu_options.allow_growth = True if trt_convert.get_linked_tensorrt_version()[0] == 3: gpu_options.per_process_gpu_memory_fraction = 0.50 @@ -114,7 +172,26 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): gpu_options=gpu_options, graph_options=graph_options) return config - def _RunGraph(self, params, gdef, input_data, config, num_runs=2): + def _ExpectTestValue(self, engine_name, method, expected_value): + label = "%s:%s" % (engine_name, method) + actual_value = trt_convert.get_test_value(label) + self.assertEqual( + expected_value, + actual_value, + msg="Unexpected test value with label %s. Actual: %s; expected: %s" % + (label, actual_value, expected_value)) + + def _ExpectCalibration(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteCalibration", value) + + def _ExpectTrtEngine(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value) + + def _ExpectNativeSegment(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value) + + def _RunGraph(self, params, gdef, input_data, config, graph_state, + num_runs=2): """Run given graphdef multiple times.""" assert len(params.input_names) == len(input_data) g = ops.Graph() @@ -131,93 +208,170 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): val = None # Defaults to 2 runs to verify result across multiple runs is same. for _ in range(num_runs): + self._PrepareRun(params, graph_state) new_val = sess.run(out, {inp[i]: input_data[i] for i in range(len(inp))}) self.assertEqual(params.expected_output_dims, new_val.shape) if val is not None: self.assertAllEqual(val, new_val) val = new_val + self._VerifyRun(params, graph_state) return val # Use real data that is representative of the inference dataset # for calibration. For this test script it is random data. def _RunCalibration(self, params, gdef, input_data, config): """Run calibration on given graph.""" - return self._RunGraph(params, gdef, input_data, config, 30) + return self._RunGraph( + params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) - def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op): + def _GetTrtGraphDef(self, params, run_params, gdef): """Return trt converted graphdef.""" return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=[self.output_name], max_batch_size=max([dims[0] for dims in params.input_dims]), max_workspace_size_bytes=1 << 25, - precision_mode=precision_mode, + precision_mode=run_params.precision_mode, minimum_segment_size=2, - is_dynamic_op=is_dynamic_op) - - def _VerifyGraphDef(self, - params, - gdef, - precision_mode=None, - is_calibrated=None, - dynamic_engine=None): + is_dynamic_op=run_params.dynamic_engine) + + def _WriteGraph(self, params, run_params, gdef, graph_state): + if graph_state == GraphState.ORIGINAL: + label = "Original" + elif graph_state == GraphState.CALIBRATE: + label = "CalibEngine" + elif graph_state == GraphState.INFERENCE: + label = "InferEngine" + graph_name = ( + self.__class__.__name__ + "_" + run_params.test_name + "_" + label + + ".pbtxt") + temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir()) + logging.info("Writing graph to %s/%s", temp_dir, graph_name) + graph_io.write_graph(gdef, temp_dir, graph_name) + + def _VerifyConnections(self, params, converted_gdef): + old_to_new_node_map = { + self._ToString(node.name): self._ToString(node.name) + for node in params.gdef.node + } + for engine_name, node_names in params.expected_engines.items(): + for node_name in node_names: + old_to_new_node_map[node_name] = engine_name + name_to_node_map = { + self._ToString(node.name): node for node in params.gdef.node + } + + def _InputName(inp): + inp = self._ToString(inp) + prefix = "" + if inp[0] == "^": + prefix = "^" + inp = inp[1:] + parts = inp.split(":") + if len(parts) > 1 and parts[-1].isdigit(): + inp = inp[:-len(parts[-1]) - 1] + return (prefix, inp) + + expected_input_map = {} + for node in params.gdef.node: + name_str = self._ToString(node.name) + target_node_name = old_to_new_node_map[name_str] + is_engine_op = (target_node_name != name_str) + if target_node_name not in expected_input_map: + expected_input_map[target_node_name] = set() + input_set = expected_input_map[target_node_name] + for inp in node.input: + (prefix, inp_name) = _InputName(inp) + # Add the input only if it's outside the segment (note that it could be + # in a different engine). + if (not is_engine_op or + old_to_new_node_map[inp_name] != target_node_name): + if is_engine_op and name_to_node_map[inp_name].op == "Const": + # Const data input nodes to the segment has been copied to the + # segment graphdef and the engine, and the dependency has been + # converted to control dependendy. + input_set.add("^" + old_to_new_node_map[inp_name]) + else: + input_set.add(prefix + old_to_new_node_map[inp_name]) + + actual_input_map = {} + for node in converted_gdef.node: + name_str = self._ToString(node.name) + actual_input_map[name_str] = set() + input_set = actual_input_map[name_str] + for inp in node.input: + (prefix, node_name) = _InputName(inp) + input_set.add(prefix + node_name) + + self.assertEqual( + expected_input_map, + actual_input_map, + msg="expected:\n%s\nvs actual:\n%s" % (sorted( + expected_input_map.items()), sorted(actual_input_map.items()))) + + def _VerifyGraphDef(self, params, run_params, gdef, graph_state): + self._WriteGraph(params, run_params, gdef, graph_state) + num_engines = 0 - for n in gdef.node: - # TODO(jie): we should have coverage for failed conversion (TF fallback). - # where the conversion will fail and we shouldn't count this engine as the - # converted engines. - if n.op == "TRTEngineOp": + for node in gdef.node: + if node.op == "TRTEngineOp": num_engines += 1 - self.assertNotEqual(self._ToBytes(""), n.attr["serialized_segment"].s) - self.assertNotEqual(self._ToBytes(""), n.attr["segment_funcdef_name"].s) + self.assertTrue(node.name in params.expected_engines) + self.assertTrue(len(node.attr["serialized_segment"].s)) + self.assertTrue(len(node.attr["segment_funcdef_name"].s)) self.assertEqual( - self._ToBytes(precision_mode), n.attr["precision_mode"].s) - self.assertEqual(not dynamic_engine, n.attr["static_engine"].b) - if _IsQuantizationMode(precision_mode) and is_calibrated: - self.assertNotEqual(self._ToBytes(""), n.attr["calibration_data"].s) + self._ToBytes(run_params.precision_mode), + node.attr["precision_mode"].s) + + is_dynamic_engine = not node.attr["static_engine"].b + self.assertEqual(run_params.dynamic_engine, is_dynamic_engine) + + has_calibration_data = len(node.attr["calibration_data"].s) + if (_IsQuantizationMode(run_params.precision_mode) and + graph_state == GraphState.INFERENCE): + self.assertTrue(has_calibration_data) else: - self.assertEqual(self._ToBytes(""), n.attr["calibration_data"].s) - if precision_mode is None: # This means gdef is the original GraphDef. + self.assertFalse(has_calibration_data) + if graph_state == GraphState.ORIGINAL: self.assertEqual(0, num_engines) else: - self.assertEqual(num_engines, params.num_expected_engines) + self.assertEqual(num_engines, len(params.expected_engines)) + if isinstance(params.expected_engines, dict): + self._VerifyConnections(params, gdef) + # TODO(aaroey): consider verifying the corresponding TF function. - def RunTest(self, params, use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine): - assert precision_mode in PRECISION_MODES + def RunTest(self, params, run_params): + assert run_params.precision_mode in PRECISION_MODES input_data = [np.random.random_sample(dims) for dims in params.input_dims] input_gdef = params.gdef - self._VerifyGraphDef(params, input_gdef) + self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL) # Get reference result without running trt. - config_no_trt = self._GetConfigProto(params, False) + config_no_trt = self._GetConfigProto(params, run_params, + GraphState.ORIGINAL) logging.info("Running original graph w/o trt, config:\n%s", str(config_no_trt)) - ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt) + ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt, + GraphState.ORIGINAL) # Run calibration if necessary. - if _IsQuantizationMode(precision_mode): + if _IsQuantizationMode(run_params.precision_mode): - calib_config = self._GetConfigProto(params, use_optimizer, precision_mode, - dynamic_calib_engine) + calib_config = self._GetConfigProto(params, run_params, + GraphState.CALIBRATE) logging.info("Running calibration graph, config:\n%s", str(calib_config)) - if use_optimizer: - self.assertTrue(False) - # TODO(aaroey): uncomment this and get infer_gdef when this mode is - # supported. - # result = self._RunCalibration(params, input_gdef, input_data, - # calib_config) + if run_params.use_optimizer: + result = self._RunCalibration(params, input_gdef, input_data, + calib_config) else: - calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode, - dynamic_calib_engine) - self._VerifyGraphDef(params, calib_gdef, precision_mode, False, - dynamic_calib_engine) + calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef) + self._VerifyGraphDef(params, run_params, calib_gdef, + GraphState.CALIBRATE) result = self._RunCalibration(params, calib_gdef, input_data, calib_config) - infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) - self._VerifyGraphDef(params, infer_gdef, precision_mode, True, - dynamic_calib_engine) + infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) + self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE) self.assertAllClose( ref_result, @@ -228,18 +382,19 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): infer_gdef = input_gdef # Run inference. - infer_config = self._GetConfigProto(params, use_optimizer, precision_mode, - dynamic_infer_engine) + infer_config = self._GetConfigProto(params, run_params, + GraphState.INFERENCE) logging.info("Running final inference graph, config:\n%s", str(infer_config)) - if use_optimizer: - result = self._RunGraph(params, infer_gdef, input_data, infer_config) + if run_params.use_optimizer: + result = self._RunGraph(params, infer_gdef, input_data, infer_config, + GraphState.INFERENCE) else: - trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode, - dynamic_infer_engine) - self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True, - dynamic_infer_engine) - result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config) + trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef) + self._VerifyGraphDef(params, run_params, trt_infer_gdef, + GraphState.INFERENCE) + result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config, + GraphState.INFERENCE) self.assertAllClose( ref_result, @@ -262,66 +417,44 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _AddTests(test_class): """Adds test methods to TfTrtIntegrationTestBase.""" - def _GetTest(use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine): + def _GetTest(run_params): """Gets a single test method based on the parameters.""" def _Test(self): params = self.GetParams() logging.info( - "Running test with parameters: use_optimizer=%s, precision_mode=%s, " - "dynamic_infer_engine=%s, dynamic_calib_engine=%s", use_optimizer, - precision_mode, dynamic_infer_engine, dynamic_calib_engine) - self.RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine) + "Running test %s with parameters: use_optimizer=%s, " + "precision_mode=%s, dynamic_engine=%s", + "testTfTrt_" + run_params.test_name, run_params.use_optimizer, + run_params.precision_mode, run_params.dynamic_engine) + self.RunTest(params, run_params) return _Test use_optimizer_options = [False, True] - dynamic_infer_engine_options = [False, True] - dynamic_calib_engine_options = [False, True] - for (use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine) in itertools.product( - use_optimizer_options, PRECISION_MODES, dynamic_infer_engine_options, - dynamic_calib_engine_options): + dynamic_engine_options = [False, True] + for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( + use_optimizer_options, PRECISION_MODES, dynamic_engine_options): if _IsQuantizationMode(precision_mode): - if not dynamic_calib_engine and dynamic_infer_engine: - # TODO(aaroey): test this case, the conversion from static calibration - # engine to dynamic inference engine should be a noop. - continue if use_optimizer: # TODO(aaroey): if use_optimizer is True we need to get the inference # graphdef using custom python wrapper class, which is not currently # supported yet. continue - if not dynamic_calib_engine: + if not dynamic_engine: # TODO(aaroey): construction of static calibration engine is not # supported yet. continue - if dynamic_calib_engine and not dynamic_infer_engine: - # TODO(aaroey): construction of static inference engine using dynamic - # calibration engine is not supported yet. - continue - else: # In non int8 mode. - if dynamic_calib_engine: - # dynamic_calib_engine doesn't affect non-int8 modes, so just let - # related tests run once on dynamic_calib_engine=False. - continue conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" - infer_engine_type = ("DynamicInferEngine" - if dynamic_infer_engine else "StaticInferEngine") - calib_engine_type = "" - if precision_mode == "INT8": - calib_engine_type = ("DynamicCalibEngine" - if dynamic_calib_engine else "StaticCalibEngine") - test_name = "%s_%s_%s%s" % (conversion, precision_mode, infer_engine_type, - ("_" + calib_engine_type) - if len(calib_engine_type) else "") - setattr( - test_class, "testTfTRT_" + test_name, - _GetTest(use_optimizer, precision_mode, dynamic_infer_engine, - dynamic_calib_engine)) + engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") + test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + run_params = RunParams( + use_optimizer=use_optimizer, + precision_mode=precision_mode, + dynamic_engine=dynamic_engine, + test_name=test_name) + setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params)) if trt_convert.is_tensorrt_enabled(): diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index b9e977cf67b4e94282c10313477276b04ea828aa..500057a36d60efa3b7f96f22e27973444ecc277c 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -100,7 +100,10 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, input2_name], input_dims=[input_dims, input2_dims], - num_expected_engines=5, + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4" + ], expected_output_dims=(12, 5, 8, 12), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..276308b3a0a6ce864969afb0179c6a3f00d6b70b --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/test/utils.h" + +#include +#include + +#include "re2/re2.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// TODO(aaroey): make this class thread-safe. +class TestValueManager { + public: + static TestValueManager* singleton() { + static TestValueManager* manager = new TestValueManager(); + return manager; + } + + void Enable() { + VLOG(1) << "Enabling test value"; + enabled_ = true; + } + + void Add(const string& label, const string& value) { + if (TF_PREDICT_FALSE(enabled_)) { + QCHECK_NE("", value); + VLOG(1) << "Adding test value: " << label << " -> " << value; + values_.insert({label, value}); + } + } + + string Get(const string& label) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Getting test value by " << label; + auto itr = values_.find(label); + if (itr == values_.end()) return ""; + return itr->second; + } + return ""; + } + + void Clear(const string& pattern) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Clearing test values"; + if (pattern.empty()) { + values_.clear(); + return; + } + std::vector keys_to_clear; + for (const auto& kv : values_) { + if (RE2::FullMatch(kv.first, pattern)) { + keys_to_clear.push_back(kv.first); + } + } + for (const string& key : keys_to_clear) { + values_.erase(key); + } + } + } + + private: + TestValueManager() : enabled_(false) {} + + bool enabled_; + std::unordered_map values_; +}; + +void EnableTestValue() { TestValueManager::singleton()->Enable(); } + +void ClearTestValues(const string& pattern) { + TestValueManager::singleton()->Clear(pattern); +} + +void AddTestValue(const string& label, const string& value) { + TestValueManager::singleton()->Add(label, value); +} + +string GetTestValue(const string& label) { + return TestValueManager::singleton()->Get(label); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..4bb4120206cfaae70107e55d1818e3af2f02717a --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// Helper methods to inject values used by testing tools. +void EnableTestValue(); +void ClearTestValues(const string& pattern); +void AddTestValue(const string& label, const string& value); +string GetTestValue(const string& label); + +#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \ + do { \ + if (::tensorflow::tensorrt::test::GetTestValue(label) == \ + value_to_return) { \ + return errors::Internal("Injected manually"); \ + } \ + } while (0) + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py index 2b134c3bce2b36e4530f8f8e58cce8d07c9bb13b..ab4d224db4d88c91c9b06d278b404879d989a834 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -72,7 +72,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(5, 6, 2, 2), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index bec2f23eff3b1799d70519462f42c326d17924c1..56bdf848eadbdde3d5896e415ecd9754ed387eeb 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -63,7 +63,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=1, + expected_engines=["my_trt_op_0"], expected_output_dims=(5, 2, 2, 6), allclose_atol=1.e-03, allclose_rtol=1.e-03) diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 422740fdf6ec381dc6f6c01e736ce8b3398586ce..6ea15fb8eff13663625420288a37ba002d57fa47 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -101,82 +101,22 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" %} %ignoreall %unignore tensorflow; -%unignore trt_convert; %unignore calib_convert; %unignore get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; %unignore is_tensorrt_enabled; +%unignore enable_test_value; +%unignore clear_test_values; +%unignore add_test_value; +%unignore get_test_value; %{ -std::pair trt_convert( - string graph_def_string, // The serialized GraphDef string. - std::vector output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, - int minimum_segment_size, - bool is_dyn_op, - int max_cached_engines, - std::vector cached_engine_batches - // Unfortunately we can't use TF_Status here since it - // is in c/c_api and brings in a lot of other libraries - // which in turn declare ops. These ops are included - // statically in our library and cause an abort when - // module is loaded due to double registration - // until Tensorflow properly exposes these headers - // we have to work around this by returning a string - // and converting it to exception on python side. - //,TF_Status* out_status) { -) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - string out_status; - - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(graph_def_string)) { - out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; - return std::pair{out_status, ""}; - } - - if (precision_mode < 0 || precision_mode > 2) { - out_status = "InvalidArgument;Invalid precision_mode"; - return std::pair{out_status, ""}; - } - if (!output_names.size()) { - out_status = "InvalidArgument;Size of the output_names vector is 0"; - return std::pair{out_status, ""}; - } - tensorflow::GraphDef out_graph; - tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( - graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &out_graph, precision_mode, minimum_segment_size, - is_dyn_op, max_cached_engines, cached_engine_batches); - if (!conversion_status.ok()) { - auto retCode = (int)conversion_status.code(); - char buff[2000]; - snprintf(buff, 2000, "%d;%s", retCode, - conversion_status.error_message().c_str()); - out_status = buff; - return std::pair{out_status, ""}; - } - string result; - if (!out_graph.SerializeToString(&result)) { - out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; - return std::pair{out_status, ""}; - } - out_status = "OK;All good!"; - return std::pair{out_status, result}; -#else - // Returns FAILED_PRECONDITION. - return std::pair{"9;TensorRT is not enabled!", ""}; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT -} - std::pair calib_convert( string graph_def_string, bool is_dyn_op // unfortunately we can't use TF_Status here since it @@ -251,20 +191,44 @@ bool is_tensorrt_enabled() { return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); } -%} +void enable_test_value() { + tensorflow::tensorrt::test::EnableTestValue(); +} + +#if PY_MAJOR_VERSION < 3 +#define TRT_PY_TO_CPP_STRING PyString_AsString +#define TRT_CPP_TO_PY_STRING PyString_FromString +#else +#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8 +#define TRT_CPP_TO_PY_STRING PyUnicode_FromString +#endif + +void clear_test_values(PyObject* pattern) { + tensorflow::tensorrt::test::ClearTestValues( + string(TRT_PY_TO_CPP_STRING(pattern))); +} + +void add_test_value(PyObject* label, PyObject* value) { + tensorflow::tensorrt::test::AddTestValue( + string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value))); +} -std::pair calib_convert(string graph_def_string, bool is_dyn_op); +PyObject* get_test_value(PyObject* label) { + string value = tensorflow::tensorrt::test::GetTestValue( + string(TRT_PY_TO_CPP_STRING(label))); + return TRT_CPP_TO_PY_STRING(value.c_str()); +} -std::pair trt_convert(string graph_def_string, - std::vector output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, int minimum_segment_size, - bool is_dyn_op, - int max_cached_engines, - std::vector cached_engine_batches); +%} + +std::pair calib_convert( + string graph_def_string, bool is_dyn_op); version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); bool is_tensorrt_enabled(); +void enable_test_value(); +void clear_test_values(PyObject* pattern); +void add_test_value(PyObject* label, PyObject* value); +PyObject* get_test_value(PyObject* label); %unignoreall diff --git a/tensorflow/contrib/timeseries/__init__.py b/tensorflow/contrib/timeseries/__init__.py index 11db56b1b7a48b401efeece91283eb7084747c14..654a4db098757a969c2d298f7ed490083e63b9da 100644 --- a/tensorflow/contrib/timeseries/__init__.py +++ b/tensorflow/contrib/timeseries/__init__.py @@ -27,6 +27,9 @@ @@TrainEvalFeatures @@FilteringResults + +@@TimeSeriesRegressor +@@OneShotPredictionHead """ from __future__ import absolute_import diff --git a/tensorflow/contrib/timeseries/examples/multivariate.py b/tensorflow/contrib/timeseries/examples/multivariate.py index ed799542fd50cd150f13533c5f33bd67ed09fff6..e81cb18ad7b928a6fd2a748ea6b258c49cf722ae 100644 --- a/tensorflow/contrib/timeseries/examples/multivariate.py +++ b/tensorflow/contrib/timeseries/examples/multivariate.py @@ -80,8 +80,8 @@ def multivariate_train_and_sample( session=session, steps=1)) next_sample = numpy.random.multivariate_normal( # Squeeze out the batch and series length dimensions (both 1). - mean=numpy.squeeze(current_prediction["mean"], axis=[0, 1]), - cov=numpy.squeeze(current_prediction["covariance"], axis=[0, 1])) + mean=numpy.squeeze(current_prediction["mean"], axis=(0, 1)), + cov=numpy.squeeze(current_prediction["covariance"], axis=(0, 1))) # Update model state so that future predictions are conditional on the # value we just sampled. filtering_features = { diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 7020989d6895fd6322db45cda6f7dd99d417d937..0e96c1fbd43ef45e9ff1e090a6d5489ab186484a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -161,6 +161,7 @@ py_test( srcs = [ "head_test.py", ], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip_gpu"], # b/63391119 deps = [ diff --git a/tensorflow/contrib/timeseries/python/timeseries/__init__.py b/tensorflow/contrib/timeseries/python/timeseries/__init__.py index c683dad71de8f8502f08a4e823faa79d60d5604d..8462138339cda8557d9c9ee6e79d4c7a67ad1aa7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/__init__.py +++ b/tensorflow/contrib/timeseries/python/timeseries/__init__.py @@ -24,5 +24,6 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils from tensorflow.contrib.timeseries.python.timeseries.ar_model import * from tensorflow.contrib.timeseries.python.timeseries.estimators import * from tensorflow.contrib.timeseries.python.timeseries.feature_keys import * +from tensorflow.contrib.timeseries.python.timeseries.head import * from tensorflow.contrib.timeseries.python.timeseries.input_pipeline import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 769183f40ad269954dac70db393207c266052144..0ddc4b4144da25206735b0480aa0886374ed43a8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.training import training as train from tensorflow.python.util import nest @@ -79,12 +80,137 @@ class TimeSeriesRegressor(estimator_lib.Estimator): model_dir=model_dir, config=config) - # TODO(allenl): A parsing input receiver function, which takes a serialized - # tf.Example containing all features (times, values, any exogenous features) - # and serialized model state (possibly also as a tf.Example). - def build_raw_serving_input_receiver_fn(self, - default_batch_size=None, - default_series_length=None): + def _model_start_state_placeholders( + self, batch_size_tensor, static_batch_size=None): + """Creates placeholders with zeroed start state for the current model.""" + gathered_state = {} + # Models may not know the shape of their state without creating some + # variables/ops. Avoid polluting the default graph by making a new one. We + # use only static metadata from the returned Tensors. + with ops.Graph().as_default(): + self._model.initialize_graph() + # Evaluate the initial state as same-dtype "zero" values. These zero + # constants aren't used, but are necessary for feeding to + # placeholder_with_default for the "cold start" case where state is not + # fed to the model. + def _zeros_like_constant(tensor): + return tensor_util.constant_value(array_ops.zeros_like(tensor)) + start_state = nest.map_structure( + _zeros_like_constant, self._model.get_start_state()) + for prefixed_state_name, state in ts_head_lib.state_to_dictionary( + start_state).items(): + state_shape_with_batch = tensor_shape.TensorShape( + (static_batch_size,)).concatenate(state.shape) + default_state_broadcast = array_ops.tile( + state[None, ...], + multiples=array_ops.concat( + [batch_size_tensor[None], + array_ops.ones(len(state.shape), dtype=dtypes.int32)], + axis=0)) + gathered_state[prefixed_state_name] = array_ops.placeholder_with_default( + input=default_state_broadcast, + name=prefixed_state_name, + shape=state_shape_with_batch) + return gathered_state + + def build_one_shot_parsing_serving_input_receiver_fn( + self, filtering_length, prediction_length, default_batch_size=None, + values_input_dtype=None, truncate_values=False): + """Build an input_receiver_fn for export_savedmodel accepting tf.Examples. + + Only compatible with `OneShotPredictionHead` (see `head`). + + Args: + filtering_length: The number of time steps used as input to the model, for + which values are provided. If more than `filtering_length` values are + provided (via `truncate_values`), only the first `filtering_length` + values are used. + prediction_length: The number of time steps requested as predictions from + the model. Times and all exogenous features must be provided for these + steps. + default_batch_size: If specified, must be a scalar integer. Sets the batch + size in the static shape information of all feature Tensors, which means + only this batch size will be accepted by the exported model. If None + (default), static shape information for batch sizes is omitted. + values_input_dtype: An optional dtype specification for values in the + tf.Example protos (either float32 or int64, since these are the numeric + types supported by tf.Example). After parsing, values are cast to the + model's dtype (float32 or float64). + truncate_values: If True, expects `filtering_length + prediction_length` + values to be provided, but only uses the first `filtering_length`. If + False (default), exactly `filtering_length` values must be provided. + + Returns: + An input_receiver_fn which may be passed to the Estimator's + export_savedmodel. + + Expects features contained in a vector of serialized tf.Examples with + shape [batch size] (dtype `tf.string`), each tf.Example containing + features with the following shapes: + times: [filtering_length + prediction_length] integer + values: [filtering_length, num features] floating point. If + `truncate_values` is True, expects `filtering_length + + prediction_length` values but only uses the first `filtering_length`. + all exogenous features: [filtering_length + prediction_length, ...] + (various dtypes) + """ + if values_input_dtype is None: + values_input_dtype = dtypes.float32 + if truncate_values: + values_proto_length = filtering_length + prediction_length + else: + values_proto_length = filtering_length + + def _serving_input_receiver_fn(): + """A receiver function to be passed to export_savedmodel.""" + times_column = feature_column.numeric_column( + key=feature_keys.TrainEvalFeatures.TIMES, dtype=dtypes.int64) + values_column = feature_column.numeric_column( + key=feature_keys.TrainEvalFeatures.VALUES, dtype=values_input_dtype, + shape=(self._model.num_features,)) + parsed_features_no_sequence = ( + feature_column.make_parse_example_spec( + list(self._model.exogenous_feature_columns) + + [times_column, values_column])) + parsed_features = {} + for key, feature_spec in parsed_features_no_sequence.items(): + if isinstance(feature_spec, parsing_ops.FixedLenFeature): + if key == feature_keys.TrainEvalFeatures.VALUES: + parsed_features[key] = feature_spec._replace( + shape=((values_proto_length,) + + feature_spec.shape)) + else: + parsed_features[key] = feature_spec._replace( + shape=((filtering_length + prediction_length,) + + feature_spec.shape)) + elif feature_spec.dtype == dtypes.string: + parsed_features[key] = parsing_ops.FixedLenFeature( + shape=(filtering_length + prediction_length,), + dtype=dtypes.string) + else: # VarLenFeature + raise ValueError("VarLenFeatures not supported, got %s for key %s" + % (feature_spec, key)) + tfexamples = array_ops.placeholder( + shape=[default_batch_size], dtype=dtypes.string, name="input") + features = parsing_ops.parse_example( + serialized=tfexamples, + features=parsed_features) + features[feature_keys.TrainEvalFeatures.TIMES] = array_ops.squeeze( + features[feature_keys.TrainEvalFeatures.TIMES], axis=-1) + features[feature_keys.TrainEvalFeatures.VALUES] = math_ops.cast( + features[feature_keys.TrainEvalFeatures.VALUES], + dtype=self._model.dtype)[:, :filtering_length] + features.update( + self._model_start_state_placeholders( + batch_size_tensor=array_ops.shape( + features[feature_keys.TrainEvalFeatures.TIMES])[0], + static_batch_size=default_batch_size)) + return export_lib.ServingInputReceiver( + features, {"examples": tfexamples}) + return _serving_input_receiver_fn + + def build_raw_serving_input_receiver_fn( + self, default_batch_size=None, default_series_length=None): """Build an input_receiver_fn for export_savedmodel which accepts arrays. Automatically creates placeholders for exogenous `FeatureColumn`s passed to @@ -149,34 +275,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator): + batch_only_feature_shape[1:]) placeholders[feature_key] = array_ops.placeholder( dtype=value_dtype, name=feature_key, shape=feature_shape) - # Models may not know the shape of their state without creating some - # variables/ops. Avoid polluting the default graph by making a new one. We - # use only static metadata from the returned Tensors. - with ops.Graph().as_default(): - self._model.initialize_graph() - # Evaluate the initial state as same-dtype "zero" values. These zero - # constants aren't used, but are necessary for feeding to - # placeholder_with_default for the "cold start" case where state is not - # fed to the model. - def _zeros_like_constant(tensor): - return tensor_util.constant_value(array_ops.zeros_like(tensor)) - start_state = nest.map_structure( - _zeros_like_constant, self._model.get_start_state()) batch_size_tensor = array_ops.shape(time_placeholder)[0] - for prefixed_state_name, state in ts_head_lib.state_to_dictionary( - start_state).items(): - state_shape_with_batch = tensor_shape.TensorShape( - (default_batch_size,)).concatenate(state.shape) - default_state_broadcast = array_ops.tile( - state[None, ...], - multiples=array_ops.concat( - [batch_size_tensor[None], - array_ops.ones(len(state.shape), dtype=dtypes.int32)], - axis=0)) - placeholders[prefixed_state_name] = array_ops.placeholder_with_default( - input=default_state_broadcast, - name=prefixed_state_name, - shape=state_shape_with_batch) + placeholders.update( + self._model_start_state_placeholders( + batch_size_tensor, static_batch_size=default_batch_size)) return export_lib.ServingInputReceiver(placeholders, placeholders) return _serving_input_receiver_fn diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index 8686a803e5bb023bbddb7df3203080fee0e13fea..32194e400e6ada594ef2a067bf612826a6e4acd3 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -26,6 +26,7 @@ from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.export import export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -180,7 +181,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce return math_ops.cast(value, self.model.dtype) if name == feature_keys.PredictionFeatures.STATE_TUPLE: return value # Correct dtypes are model-dependent - return ops.convert_to_tensor(value) + return sparse_tensor.convert_to_tensor_or_sparse_tensor(value) def _gather_state(self, features): """Returns `features` with state packed, indicates if packing was done.""" @@ -202,6 +203,29 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce flat_sequence=[tensor for _, _, tensor in numbered_state]) return features, True + def _check_predict_features(self, features): + """Raises errors if features are not suitable for prediction.""" + if feature_keys.PredictionFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.TIMES)) + if feature_keys.PredictionFeatures.STATE_TUPLE not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.STATE_TUPLE)) + times_feature = features[feature_keys.PredictionFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + # Model-dependent shapes + feature_keys.PredictionFeatures.STATE_TUPLE + ])) + def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope(self._name, "head"): @@ -230,7 +254,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce mode == estimator_lib.ModeKeys.EVAL): _check_train_eval_features(features, self.model) elif mode == estimator_lib.ModeKeys.PREDICT: - _check_predict_features(features) + self._check_predict_features(features) else: raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) @@ -267,6 +291,44 @@ class OneShotPredictionHead(TimeSeriesRegressionHead): each time predictions are requested when using this head. """ + def _check_predict_features(self, features): + """Raises errors if features are not suitable for one-shot prediction.""" + if feature_keys.PredictionFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.TIMES)) + if feature_keys.TrainEvalFeatures.VALUES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.TrainEvalFeatures.VALUES)) + if feature_keys.PredictionFeatures.STATE_TUPLE not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.STATE_TUPLE)) + times_feature = features[feature_keys.PredictionFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + # Model-dependent shapes + feature_keys.PredictionFeatures.STATE_TUPLE, + # One shot prediction head relies on values being shorter than + # times. Even though we're predicting eventually, we need values for + # the filtering phase. + feature_keys.TrainEvalFeatures.VALUES, + ])) + + def _evaluate_ops(self, features): + """Add ops for evaluation (aka filtering) to the graph.""" + spec = super(OneShotPredictionHead, self)._evaluate_ops(features) + # No state is fed to OneShotPredictionHead, so we don't return it; it being + # a tuple can cause issues for downstream infrastructure. + del spec.eval_metric_ops[feature_keys.State.STATE_TUPLE] + return spec + def _serving_ops(self, features): """Add ops for serving to the graph.""" with variable_scope.variable_scope("model", use_resource=True): @@ -333,29 +395,6 @@ def _check_feature_shapes_compatible_with(features, times_shape=compatible_with_value.get_shape())) -def _check_predict_features(features): - """Raises errors if features are not suitable for prediction.""" - if feature_keys.PredictionFeatures.TIMES not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.TIMES)) - if feature_keys.PredictionFeatures.STATE_TUPLE not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.STATE_TUPLE)) - times_feature = features[feature_keys.PredictionFeatures.TIMES] - if not times_feature.get_shape().is_compatible_with([None, None]): - raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, - times_feature.get_shape())) - _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.PredictionFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes - ])) - - def _check_train_eval_features(features, model): """Raise errors if features are not suitable for training/evaluation.""" if feature_keys.TrainEvalFeatures.TIMES not in features: diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 78c2cec21cf4b6ccf6c314e54de41f3e95466adf..bda3b53aca0d0156e542e2bedcadf5caa6b3d2cf 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import os from absl.testing import parameterized @@ -26,12 +27,14 @@ import six from tensorflow.contrib.estimator.python.estimator import extenders from tensorflow.contrib.timeseries.examples import lstm as lstm_example +from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib from tensorflow.contrib.timeseries.python.timeseries import input_pipeline from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import state_management +from tensorflow.core.example import example_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.estimator import estimator_lib @@ -343,15 +346,33 @@ def _structural_ensemble_regressor( model_dir=model_dir) +def _ar_lstm_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.TimeSeriesRegressor( + model=ar_model.ARModel( + periodicities=10, input_window_size=10, output_window_size=6, + num_features=5, + exogenous_feature_columns=exogenous_feature_columns, + prediction_model_factory=functools.partial( + ar_model.LSTMPredictionModel, + num_units=10)), + head_type=head_type, + model_dir=model_dir) + + class OneShotTests(parameterized.TestCase): @parameterized.named_parameters( + {"testcase_name": "ar_lstm_regressor", + "estimator_factory": _ar_lstm_regressor}, {"testcase_name": "custom_time_series_regressor", "estimator_factory": _custom_time_series_regressor}, {"testcase_name": "structural_ensemble_regressor", "estimator_factory": _structural_ensemble_regressor}) def test_one_shot_prediction_head_export(self, estimator_factory): - model_dir = os.path.join(test.get_temp_dir(), str(ops.uid())) + def _new_temp_dir(): + return os.path.join(test.get_temp_dir(), str(ops.uid())) + model_dir = _new_temp_dir() categorical_column = feature_column.categorical_column_with_hash_bucket( key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ @@ -376,8 +397,10 @@ class OneShotTests(parameterized.TestCase): input_pipeline.NumpyReader(train_features), shuffle_seed=2, num_threads=1, batch_size=16, window_size=16) estimator.train(input_fn=train_input_fn, steps=5) + result = estimator.evaluate(input_fn=train_input_fn, steps=1) + self.assertNotIn(feature_keys.State.STATE_TUPLE, result) input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() - export_location = estimator.export_savedmodel(test.get_temp_dir(), + export_location = estimator.export_savedmodel(_new_temp_dir(), input_receiver_fn) graph = ops.Graph() with graph.as_default(): @@ -412,6 +435,41 @@ class OneShotTests(parameterized.TestCase): in predict_signature.outputs.items()} output = session.run(fetches, feed_dict=feeds) self.assertEqual((2, 15, 5), output["mean"].shape) + # Build a parsing input function, then make a tf.Example for it to parse. + export_location = estimator.export_savedmodel( + _new_temp_dir(), + estimator.build_one_shot_parsing_serving_input_receiver_fn( + filtering_length=20, prediction_length=15)) + graph = ops.Graph() + with graph.as_default(): + with session_lib.Session() as session: + example = example_pb2.Example() + times = example.features.feature[feature_keys.TrainEvalFeatures.TIMES] + values = example.features.feature[feature_keys.TrainEvalFeatures.VALUES] + times.int64_list.value.extend(range(35)) + for i in range(20): + values.float_list.value.extend( + [float(i) * 2. + feature_number + for feature_number in range(5)]) + real_feature = example.features.feature["2d_exogenous_feature"] + categortical_feature = example.features.feature[ + "categorical_exogenous_feature"] + for i in range(35): + real_feature.float_list.value.extend([1, 1]) + categortical_feature.bytes_list.value.append(b"strkey") + # Serialize the tf.Example for feeding to the Session + examples = [example.SerializeToString()] * 2 + signatures = loader.load( + session, [tag_constants.SERVING], export_location) + predict_signature = signatures.signature_def[ + feature_keys.SavedModelLabels.PREDICT] + ((_, input_value),) = predict_signature.inputs.items() + feeds = {graph.as_graph_element(input_value.name): examples} + fetches = {output_key: graph.as_graph_element(output_value.name) + for output_key, output_value + in predict_signature.outputs.items()} + output = session.run(fetches, feed_dict=feeds) + self.assertEqual((2, 15, 5), output["mean"].shape) if __name__ == "__main__": diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 643a7cc13aaa465e13fc05425b36369ccd547dfb..1669f6050e7ca92d973a75258a6b57bc62facff2 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -15,6 +15,7 @@ package( default_visibility = [ "//cloud/vmm/testing/tests/tpu:__subpackages__", "//learning/brain:__subpackages__", + "//learning/deepmind:__subpackages__", "//tensorflow:__subpackages__", ], ) @@ -40,13 +41,13 @@ py_library( "python/tpu/tpu_config.py", "python/tpu/tpu_context.py", "python/tpu/tpu_estimator.py", - "python/tpu/tpu_system_metadata.py", "python/tpu/util.py", ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", - ":tpu_py", + "//tensorflow/compiler/xla/experimental/xla_sharding", + "//tensorflow/compiler/xla/python_api:xla_shape", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -133,7 +134,7 @@ py_library( tf_custom_op_py_library( name = "tpu_py", - srcs = glob(["python/ops/*.py"]) + ["__init__.py"], + srcs = glob(["python/ops/*.py"]), dso = [":python/ops/_tpu_ops.so"], kernels = [ ":all_ops", @@ -152,9 +153,13 @@ tf_custom_op_py_library( py_library( name = "tpu", - srcs = ["python/tpu/__init__.py"], + srcs = [ + "__init__.py", + "python/tpu/__init__.py", + ], srcs_version = "PY2AND3", deps = [ + ":keras_support", # split out to avoid cycle with tpu_strategy ":tpu_estimator", ":tpu_lib", ], @@ -169,19 +174,13 @@ py_library( visibility = [ "//cloud/vmm/testing/tests/tpu:__subpackages__", "//learning/brain:__subpackages__", - # TODO(b/111651964): Clean special visibility for keras_support. - # - # Note: If you are an end user, please do not add your project to this - # visibility. This feature is experimental, and will be made public - # when ready. - "//third_party/cloud_tpu/models/keras:__subpackages__", "//tensorflow:__subpackages__", + "//third_party/cloud_tpu/models/keras:__subpackages__", ], deps = [ ":tpu_lib", - ":tpu_py", "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", - "//tensorflow/contrib/distribute/python:tpu_strategy", + "//tensorflow/contrib/distribute", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/core:protos_all_py", @@ -217,6 +216,7 @@ py_library( "python/tpu/tpu_function.py", "python/tpu/tpu_optimizer.py", "python/tpu/tpu_sharding.py", + "python/tpu/tpu_system_metadata.py", "python/tpu/training_loop.py", ], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index d5484e9032fb874e9f608ec398be4cd03b2aaf32..537d94b7979af3e4bd3fb7392c8dcc5a210e98af 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -18,6 +18,10 @@ @@cross_replica_sum @@infeed_dequeue @@infeed_dequeue_tuple +@@infeed_enqueue +@@infeed_enqueue_tuple +@@outfeed_dequeue +@@outfeed_dequeue_tuple @@outfeed_enqueue @@outfeed_enqueue_tuple @@ -47,6 +51,9 @@ @@InputPipelineConfig @@TPUConfig @@bfloat16_scope + +@@TPUDistributionStrategy +@@keras_to_tpu_model """ from __future__ import absolute_import @@ -58,11 +65,13 @@ from tensorflow.contrib.tpu.python import profiler from tensorflow.contrib.tpu.python.ops.tpu_ops import * from tensorflow.contrib.tpu.python.tpu.bfloat16 import * from tensorflow.contrib.tpu.python.tpu.device_assignment import * +from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model +from tensorflow.contrib.tpu.python.tpu.keras_support import TPUDistributionStrategy from tensorflow.contrib.tpu.python.tpu.topology import * from tensorflow.contrib.tpu.python.tpu.tpu import * from tensorflow.contrib.tpu.python.tpu.tpu_config import * from tensorflow.contrib.tpu.python.tpu.tpu_estimator import * -from tensorflow.contrib.tpu.python.tpu.tpu_feed import * +from tensorflow.contrib.tpu.python.tpu.tpu_feed import InfeedQueue from tensorflow.contrib.tpu.python.tpu.tpu_optimizer import * from tensorflow.contrib.tpu.python.tpu.training_loop import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index f80f5652af79d410946971573ae160fdd0b85f6d..8e6e9aa0cded630f39bfd699def37e06a8b920e8 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -84,8 +84,6 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.add_tools("memory_viewer"); request.add_tools("overview_page"); *request.mutable_opts() = opts; - std::cout << "Limiting the number of trace events to " << kMaxEvents - << std::endl; return request; } @@ -99,7 +97,6 @@ bool Profile(const string& service_addr, const string& logdir, int duration_ms, ::grpc::ClientContext context; ::grpc::ChannelArguments channel_args; - // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available. // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their // `ValidateHostPortPair` checks for empty host string case. channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, @@ -166,6 +163,85 @@ bool NewSession(const string& service_addr, return new_session_response.empty_trace(); } +// Starts tracing on a single or multiple TPU hosts and saves the result in the +// given logdir. If no trace was collected, retries tracing for +// num_tracing_attempts. +void StartTracing(const tensorflow::string& service_addr, + const tensorflow::string& logdir, + const tensorflow::string& workers_list, + bool include_dataset_ops, int duration_ms, + int num_tracing_attempts) { + // Use the current timestamp as the run name. + tensorflow::string session_id = GetCurrentTimeStampAsString(); + constexpr char kProfilePluginDirectory[] = "plugins/profile/"; + tensorflow::string repository_root = + io::JoinPath(logdir, kProfilePluginDirectory); + std::vector hostnames = + tensorflow::str_util::Split(workers_list, ","); + + bool empty_trace = false; + int remaining_attempts = num_tracing_attempts; + tensorflow::ProfileOptions opts; + opts.set_include_dataset_ops(include_dataset_ops); + while (true) { + std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " + << "Remaining attempt(s): " << remaining_attempts-- << std::endl; + if (hostnames.empty()) { + empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms, + repository_root, session_id, opts); + } else { + tensorflow::string tpu_master = service_addr; + empty_trace = + tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, + repository_root, session_id, opts); + } + if (remaining_attempts <= 0 || !empty_trace) break; + std::cout << "No trace event is collected. Automatically retrying." + << std::endl + << std::endl; + } + + if (empty_trace) { + std::cout << "No trace event is collected after " << num_tracing_attempts + << " attempt(s). " + << "Perhaps, you want to try again (with more attempts?)." + << std::endl + << "Tip: increase number of attempts with --num_tracing_attempts." + << std::endl; + } +} + +MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { + MonitorRequest request; + request.set_duration_ms(duration_ms); + request.set_monitoring_level(monitoring_level); + return request; +} + +// Repeatedly collects profiles and shows user-friendly metrics for +// 'num_queries' time(s). +void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, + int monitoring_level, int num_queries) { + for (int query = 0; query < num_queries; ++query) { + MonitorRequest request = + PopulateMonitorRequest(duration_ms, monitoring_level); + + ::grpc::ClientContext context; + ::grpc::ChannelArguments channel_args; + channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, + std::numeric_limits::max()); + std::unique_ptr stub = + TPUProfiler::NewStub(::grpc::CreateCustomChannel( + "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), + channel_args)); + MonitorResponse response; + TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response))); + + std::cout << "Xprof Monitoring Results (Sample " << query + 1 << "):\n\n" + << response.data() << std::flush; + } +} + } // namespace } // namespace tpu } // namespace tensorflow @@ -174,9 +250,11 @@ int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; tensorflow::string FLAGS_workers_list; - int FLAGS_duration_ms = 2000; + int FLAGS_duration_ms = 0; int FLAGS_num_tracing_attempts = 3; bool FLAGS_include_dataset_ops = true; + int FLAGS_monitoring_level = 0; + int FLAGS_num_queries = 100; std::vector flag_list = { tensorflow::Flag("service_addr", &FLAGS_service_addr, "Address of TPU profiler service e.g. localhost:8466"), @@ -186,21 +264,38 @@ int main(int argc, char** argv) { tensorflow::Flag("logdir", &FLAGS_logdir, "Path of TensorBoard log directory e.g. /tmp/tb_log, " "gs://tb_bucket"), - tensorflow::Flag("duration_ms", &FLAGS_duration_ms, - "Duration of tracing in ms. Default is 2000ms."), + tensorflow::Flag( + "duration_ms", &FLAGS_duration_ms, + "Duration of tracing or monitoring in ms. Default is 2000ms for " + "tracing and 1000ms for monitoring."), tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts, "Automatically retry N times when no trace event " "is collected. Default is 3."), tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops, "Set to false to profile longer TPU device traces."), - }; + tensorflow::Flag("monitoring_level", &FLAGS_monitoring_level, + "Choose a monitoring level between 1 and 2 to monitor " + "your TPU job continuously. Level 2 is more verbose " + "than level 1 and shows more metrics."), + tensorflow::Flag("num_queries", &FLAGS_num_queries, + "This script will run monitoring for num_queries before " + "it stops.")}; std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION << std::endl; tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) { + if (!parse_ok || FLAGS_service_addr.empty() || + (FLAGS_logdir.empty() && FLAGS_monitoring_level == 0)) { + // Fail if flags are not parsed correctly or service_addr not provided. + // Also, fail if neither logdir is provided (required for tracing) nor + // monitoring level is provided (required for monitoring). + std::cout << usage.c_str() << std::endl; + return 2; + } + if (FLAGS_monitoring_level < 0 || FLAGS_monitoring_level > 2) { + // Invalid monitoring level. std::cout << usage.c_str() << std::endl; return 2; } @@ -213,52 +308,27 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); - // Sets the minimum duration_ms and tracing attempts to one. - int duration_ms = std::max(FLAGS_duration_ms, 1); - int remaining_attempts = std::max(FLAGS_num_tracing_attempts, 1); - tensorflow::ProfileOptions opts; - opts.set_include_dataset_ops(FLAGS_include_dataset_ops); - tensorflow::ProfileResponse response; - - // Use the current timestamp as the run name. - tensorflow::string session_id = - tensorflow::tpu::GetCurrentTimeStampAsString(); - constexpr char kProfilePluginDirectory[] = "plugins/profile/"; - tensorflow::string repository_root = - ::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory); - std::vector hostnames = - tensorflow::str_util::Split(FLAGS_workers_list, ","); - - bool empty_trace = false; - while (true) { - std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " - << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - if (hostnames.empty()) { - empty_trace = tensorflow::tpu::Profile(FLAGS_service_addr, FLAGS_logdir, - duration_ms, repository_root, - session_id, opts); - } else { - tensorflow::string tpu_master = FLAGS_service_addr; - empty_trace = - tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, - repository_root, session_id, opts); - } - if (remaining_attempts <= 0 || !empty_trace) break; - std::cout << "No trace event is collected. Automatically retrying." - << std::endl - << std::endl; + // Sets the minimum duration_ms, tracing attempts and num queries. + int duration_ms = std::max(FLAGS_duration_ms, 0); + if (duration_ms == 0) { + // If profiling duration was not set by user or set to a negative value, we + // set it to default values of 2000ms for tracing and 1000ms for monitoring. + duration_ms = FLAGS_monitoring_level == 0 ? 2000 : 1000; } + int num_tracing_attempts = std::max(FLAGS_num_tracing_attempts, 1); + int num_queries = std::max(FLAGS_num_queries, 1); - if (empty_trace) { - std::cout << "No trace event is collected after " - << FLAGS_num_tracing_attempts << " attempt(s). " - << "Perhaps, you want to try again (with more attempts?)." - << std::endl - << "Tip: increase number of attempts with --num_tracing_attempts." + if (FLAGS_monitoring_level != 0) { + std::cout << "Since monitoring level is provided, profile " + << FLAGS_service_addr << " for " << duration_ms + << "ms and show metrics for " << num_queries << " time(s)." << std::endl; - // Don't dump profile data if no trace is collected. - return 0; + tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms, + FLAGS_monitoring_level, num_queries); + } else { + tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir, + FLAGS_workers_list, FLAGS_include_dataset_ops, + duration_ms, num_tracing_attempts); } - return 0; } diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 7a5d01cca42351f6d4d8b41d43756560ce7874d3..438f4428483a86b75ca1feb31d9c43f860fcc287 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -50,7 +50,8 @@ flags.DEFINE_string( flags.DEFINE_string( 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' 'gs://tb_bucket') -flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') +flags.DEFINE_integer('duration_ms', 0, + 'Duration of tracing or monitoring in ms.') flags.DEFINE_integer( 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' 'event is collected.') @@ -58,6 +59,14 @@ flags.DEFINE_boolean('include_dataset_ops', True, 'Set to false to profile longer TPU ' 'device traces.') +# Monitoring parameters +flags.DEFINE_integer( + 'monitoring_level', 0, 'Choose a monitoring level between ' + '1 and 2 to monitor your TPU job continuously.') +flags.DEFINE_integer( + 'num_queries', 100, + 'This script will run monitoring for num_queries before it stops.') + FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' JOB_NAME = 'worker' @@ -118,6 +127,8 @@ def main(unused_argv=None): cmd.append('--duration_ms=' + str(FLAGS.duration_ms)) cmd.append('--num_tracing_attempts=' + str(FLAGS.num_tracing_attempts)) cmd.append('--include_dataset_ops=' + str(FLAGS.include_dataset_ops).lower()) + cmd.append('--monitoring_level=' + str(FLAGS.monitoring_level)) + cmd.append('--num_queries=' + str(FLAGS.num_queries)) subprocess.call(cmd) diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index f0fca63db0bca80cdaa27e491b2a03ae2246c007..da4a95e0450a9d0c20593ca60b69f3ad467d455d 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -11,6 +11,9 @@ service TPUProfiler { // Starts a profiling session, blocks until it completes, and returns data. rpc Profile(ProfileRequest) returns (ProfileResponse) { } + // Collects profiling data and returns user-friendly metrics. + rpc Monitor(MonitorRequest) returns (MonitorResponse) { + } } message ProfileOptions { @@ -104,3 +107,26 @@ message ProfileResponse { // next-field: 8 } + +message MonitorRequest { + // Duration for which to profile between each update. + uint64 duration_ms = 1; + + // Indicates the level at which we want to monitor. Currently, two levels are + // supported: + // Level 1: An ultra lightweight mode that captures only some utilization + // metrics. + // Level 2: More verbose than level 1. Collects utilization metrics, device + // information, step time information, etc. Do not use this option if the TPU + // host is being very heavily used. + int32 monitoring_level = 2; + + // next-field: 3 +} + +message MonitorResponse { + // Properly formatted string data that can be directly returned back to user. + string data = 1; + + // next-field: 2 +} diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index 9150606f5eb254a824f4632df53329cbf2edf570..2cc17d6d928370afbb0e3b1e89252f7a687c27d3 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -1,10 +1,12 @@ -syntax = "proto2"; +syntax = "proto3"; package tensorflow.tpu; +import "google/protobuf/wrappers.proto"; + message ClippingLimits { - optional float lower = 1 [default = -inf]; - optional float upper = 2 [default = inf]; + google.protobuf.FloatValue lower = 1; // -inf if not set + google.protobuf.FloatValue upper = 2; // +inf if not set } // Get the learning rate from a source that can change @@ -21,18 +23,18 @@ message LearningRate { } message AdagradParameters { - optional float initial_accumulator = 1 [default = 0.]; + float initial_accumulator = 1; } message StochasticGradientDescentParameters { } message FtrlParameters { - optional float l1 = 1 [default = 0.]; - optional float l2 = 2 [default = 0.]; - optional float lr_power = 3 [default = 0.]; - optional float initial_accum = 4 [default = 0.]; - optional float initial_linear = 5 [default = 0.]; + float l1 = 1; + float l2 = 2; + float lr_power = 3; + float initial_accum = 4; + float initial_linear = 5; } // The Adam optimizer does not implement hyper-parameter update; use the dynamic @@ -41,84 +43,84 @@ message FtrlParameters { // Here, t is the current timestep. // https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54 message AdamParameters { - optional float beta1 = 3 [default = 0.]; - optional float beta2 = 4 [default = 0.]; - optional float epsilon = 5 [default = 0.]; - optional float initial_m = 6 [default = 0.]; - optional float initial_v = 7 [default = 0.]; + float beta1 = 3; + float beta2 = 4; + float epsilon = 5; + float initial_m = 6; + float initial_v = 7; } message MomentumParameters { - optional float momentum = 1 [default = 0.]; - optional bool use_nesterov = 2 [default = false]; - optional float initial_accum = 3 [default = 0.]; + float momentum = 1; + bool use_nesterov = 2; + float initial_accum = 3; } message RmsPropParameters { - optional float rho = 1 [default = 0.]; - optional float momentum = 2 [default = 0.]; - optional float epsilon = 3 [default = 0.]; - optional float initial_ms = 4 [default = 0.]; - optional float initial_mom = 5 [default = 0.]; + float rho = 1; + float momentum = 2; + float epsilon = 3; + float initial_ms = 4; + float initial_mom = 5; } message CenteredRmsPropParameters { - optional float rho = 1 [default = 0.]; - optional float momentum = 2 [default = 0.]; - optional float epsilon = 3 [default = 0.]; - optional float initial_ms = 4 [default = 0.]; - optional float initial_mom = 5 [default = 0.]; - optional float initial_mg = 6 [default = 0.]; + float rho = 1; + float momentum = 2; + float epsilon = 3; + float initial_ms = 4; + float initial_mom = 5; + float initial_mg = 6; } message MdlAdagradLightParameters { - optional float l2 = 1; - optional float lr_power = 2; - optional float min_servable_mdl_benefit = 3; - optional float mdl_mix_in_margin = 4; - optional float mdl_benefit_rampup_coeff = 5; - optional float mdl_min_weight = 6; - optional float benefit_revisit_scale = 7; - optional float max_event_benefit = 8; - optional float max_total_benefit = 9; - optional float mdl_hard_limit = 10; - optional bool hard_limit_min_benefit = 11; - optional bool mdl_regularize = 12; - optional float initial_accumulator = 13; - optional float initial_weight = 14; - optional float initial_benefit = 15; + float l2 = 1; + float lr_power = 2; + float min_servable_mdl_benefit = 3; + float mdl_mix_in_margin = 4; + float mdl_benefit_rampup_coeff = 5; + float mdl_min_weight = 6; + float benefit_revisit_scale = 7; + float max_event_benefit = 8; + float max_total_benefit = 9; + float mdl_hard_limit = 10; + bool hard_limit_min_benefit = 11; + bool mdl_regularize = 12; + float initial_accumulator = 13; + float initial_weight = 14; + float initial_benefit = 15; } message AdadeltaParameters { - optional float rho = 1; - optional float epsilon = 2; - optional float initial_accumulator = 3 [default = 0.]; - optional float initial_update = 4 [default = 0.]; + float rho = 1; + float epsilon = 2; + float initial_accumulator = 3; + float initial_update = 4; } message ProximalAdagradParameters { - optional float l1 = 1; - optional float l2 = 2; - optional float initial_accumulator = 3; + float l1 = 1; + float l2 = 2; + float initial_accumulator = 3; } message OptimizationParameters { // Learning rate used for updating the embedding layer parameters. - optional LearningRate learning_rate = 13; + LearningRate learning_rate = 13; reserved 1; // Old learning rate tag. // Limits to which to clip the weight values after the backward pass; not // present means no limits are applied. - optional ClippingLimits clipping_limits = 2; + ClippingLimits clipping_limits = 2; // Limits to which to clip the backward pass gradient before using it for // updates; not present means no limits are applied. - optional ClippingLimits gradient_clipping_limits = 7; + ClippingLimits gradient_clipping_limits = 7; // Whether to use gradient accumulation (do two passes over the input // gradients: one to accumulate them into a temporary array and another to // apply them using the actual optimization algorithm). - optional bool use_gradient_accumulation = 15 [default = false]; + bool use_gradient_accumulation = 15; // Optimization algorithm parameters; which field is selected determines which // algorithm to use. @@ -140,7 +142,7 @@ message OptimizationParameters { // value vector and any extra accumulators, etc.). message StateVariableSpecification { // Parameter name for the state variable. - optional string name = 1; + string name = 1; // A normal state variable that should be saved and restored in checkpoints // and used as an input or output to non-debug TensorFlow ops. @@ -151,7 +153,7 @@ message StateVariableSpecification { // from users (used for intermediate gradients being accumulated, for // example). message FillWithConstant { - optional double initial_value = 1; + double initial_value = 1; } // Usage type of this state variable. diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 726b2d248e3086e1882004827076ed3e563d960d..471b1fa46c679dcab70e9bc12d61ada84cba79bb 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -175,6 +175,8 @@ class DeviceAssignment(object): """Returns the physical topology coordinates of a logical core.""" if logical_core is None: logical_core = np.array([0, 0, 0], np.int32) + else: + logical_core = np.asarray(logical_core) if any(logical_core < 0) or any(logical_core >= self.computation_shape): raise ValueError("Invalid core {}; computation shape is {}".format( diff --git a/tensorflow/contrib/tpu/python/tpu/error_handling.py b/tensorflow/contrib/tpu/python/tpu/error_handling.py index 14659fe68f7231433e052c2fa05efbec029f9cf6..52e1ea42370d653d1de7c12eee4b456ec7ce921c 100644 --- a/tensorflow/contrib/tpu/python/tpu/error_handling.py +++ b/tensorflow/contrib/tpu/python/tpu/error_handling.py @@ -19,10 +19,11 @@ from __future__ import division from __future__ import print_function import contextlib +import sys import threading import time -import traceback +import six from tensorflow.python.framework import errors from tensorflow.python.platform import tf_logging as logging @@ -51,7 +52,7 @@ class ErrorRendezvous(object): self._num_sources = num_sources self._session_cancel_timer = None - def record_error(self, source, exception, session=None): + def record_error(self, source, exc_info, session=None): """Report an exception from the given source. If a session is passed, a timer will be registered to close it after a few @@ -61,12 +62,12 @@ class ErrorRendezvous(object): Args: source: string, source of the error - exception: Exception being thrown + exc_info: Output from `sys.exc_info` (type, value, traceback) session: Session to close after delay. """ - logging.info('Error recorded from %s: %s', source, exception) - stack_trace = traceback.format_exc() - self._errors[source] = (exception, stack_trace) + _, value, _ = exc_info + self._errors[source] = exc_info + logging.info('Error recorded from %s: %s', source, value) if session is not None and self._session_cancel_timer is None: @@ -98,8 +99,8 @@ class ErrorRendezvous(object): """Context manager to report any errors within a block.""" try: yield - except Exception as e: # pylint: disable=broad-except - self.record_error(source, e, session) + except Exception: # pylint: disable=broad-except + self.record_error(source, sys.exc_info(), session) def raise_errors(self, timeout_sec=0): """Wait for up to `timeout` seconds for all error sources to finish. @@ -117,16 +118,15 @@ class ErrorRendezvous(object): kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] - if not kept_errors: - return - # First check for any interesting errors, then fall back on the session # cancelled errors etc. - for k, (exc, _) in kept_errors: - if isinstance(exc, _UNINTERESTING_ERRORS): + for k, (typ, value, traceback) in kept_errors: + if isinstance(value, _UNINTERESTING_ERRORS): continue else: - raise exc + logging.warn('Reraising captured error') + six.reraise(typ, value, traceback) - for k, (exc, _) in kept_errors: - raise exc + for k, (typ, value, traceback) in kept_errors: + logging.warn('Reraising captured error') + six.reraise(typ, value, traceback) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 81798ee42313cb9e2232a4796f56d4d16068b82f..ff893a722f4e77c743edd3b8db77aa90be1e498d 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -55,7 +55,6 @@ import time import numpy as np from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver -from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops @@ -82,7 +81,11 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name + +# Work-around dependency cycle between DistributionStrategy and TPU lib. +def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name + from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top + return tpu_strategy.TPUStrategy(*args, **kw) class TPUEmbedding(embeddings.Embedding): @@ -1130,7 +1133,7 @@ Output shape: %(output_shape)s 'layer': layer, 'input_shape': layer.input_shape, 'output_shape': layer.output_shape - }) + }) @experimental diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 06885bbc25801d00dcad341529cddef97b828b88..7fa06d6d560a4b6ffa6d9a3fd0fa208b4c60ee7f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -314,7 +314,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # Capture the device function stack at the time of first entry # since that is the stack that will be used outside_compilation. graph = ops.get_default_graph() - self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access + # pylint: disable=protected-access + self._outer_device_function_stack = graph._device_function_stack.copy() + # pylint: enable=protected-access super(TPUReplicateContext, self).Enter() def HostComputeCore(self): @@ -968,8 +970,15 @@ def rewrite(computation, Args: computation: A Python function that builds a computation to apply to the input. If the function takes n inputs, 'inputs' should be - a list of n tensors. If the function returns m outputs, rewrite - will return a list of m tensors. + a list of n tensors. + + `computation` may return a list of operations and tensors. Tensors must + come before operations in the returned list. The return value of + `rewrite` is a list of tensors corresponding to the tensors from the + from `computation`. + + All `Operation`s returned from `computation` will be executed when + evaluating any of the returned output tensors. inputs: A list of input tensors or `None` (equivalent to an empty list). infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. @@ -1006,6 +1015,19 @@ _BLACKLISTED_INFERENCE_OPS = set([ ]) +def under_tpu_inference_context(): + """Check if it is currently under `tpu.rewrite_for_inference()`.""" + graph = ops.get_default_graph() + + context = graph._get_control_flow_context() # pylint: disable=protected-access + while context: + if isinstance(context, _TPUInferenceContext): + return True + context = context.outer_context + + return False + + class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): """A `ControlFlowContext` for nodes inside a TPU inference computation. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 9e010922dcf565e78944bd77d49f7d3fa07f2cc4..8d05e081a7c6e0327fedae6dc2c3ba45df40d029 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -44,7 +44,6 @@ class InputPipelineConfig(object): BROADCAST = 4 -# TODO(b/72511246) Provide a simplified api to configure model parallelism. class TPUConfig( collections.namedtuple('TPUConfig', [ 'iterations_per_loop', @@ -53,6 +52,7 @@ class TPUConfig( 'per_host_input_for_training', 'tpu_job_name', 'initial_infeed_sleep_secs', + 'input_partition_dims', ])): r"""TPU related configuration required by `TPUEstimator`. @@ -90,6 +90,17 @@ class TPUConfig( initial_infeed_sleep_secs: The number of seconds the infeed thread should wait before enqueueing the first batch. This helps avoid timeouts for models that require a long compilation time. + input_partition_dims: A nested list to describe the partition dims + for all the tensors from input_fn(). The structure of + input_partition_dims must match the structure of `features` and + `labels` from input_fn(). The total number of partitions must match + `num_cores_per_replica`. For example, if input_fn() returns two tensors: + images with shape [N, H, W, C] and labels [N]. + input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4 + pieces and feed into 4 TPU cores. labels tensor are directly broadcasted + to all the TPU cores since the partition dims is `None`. + Current limitations: This feature is only supported with the PER_HOST_V2 + input mode. Raises: ValueError: If `computation_shape` or `computation_shape` are invalid. @@ -101,7 +112,8 @@ class TPUConfig( num_cores_per_replica=None, per_host_input_for_training=True, tpu_job_name=None, - initial_infeed_sleep_secs=None): + initial_infeed_sleep_secs=None, + input_partition_dims=None): # Check iterations_per_loop. util_lib.check_positive_integer(iterations_per_loop, @@ -111,6 +123,20 @@ class TPUConfig( if num_shards is not None: util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') + if input_partition_dims is not None: + if len(input_partition_dims) != 1 and len(input_partition_dims) != 2: + raise ValueError( + 'input_partition_dims must be a list/tuple with one or two' + ' elements.') + + if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2: + raise ValueError( + 'input_partition_dims is only supported in PER_HOST_V2 mode.') + + if num_cores_per_replica is None: + raise ValueError( + 'input_partition_dims requires setting num_cores_per_replica.') + # Parse computation_shape if num_cores_per_replica is not None: if num_cores_per_replica not in [1, 2, 4, 8]: @@ -139,7 +165,8 @@ class TPUConfig( num_cores_per_replica=num_cores_per_replica, per_host_input_for_training=per_host_input_for_training, tpu_job_name=tpu_job_name, - initial_infeed_sleep_secs=initial_infeed_sleep_secs) + initial_infeed_sleep_secs=initial_infeed_sleep_secs, + input_partition_dims=input_partition_dims) class RunConfig(run_config_lib.RunConfig): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index a9cf54f77d8192b51af094e71707a958594874f6..806ae1c4c9918be0bf0af8579c12386c0a18aff0 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -232,11 +232,16 @@ class _InternalTPUContext(object): if tpu_system_metadata is not None: return tpu_system_metadata + cluster_def = None + if (self._config.session_config and + self._config.session_config.cluster_def.job): + cluster_def = self._config.session_config.cluster_def + # pylint: disable=protected-access tpu_system_metadata = ( tpu_system_metadata_lib._query_tpu_system_metadata( master, - run_config=self._config, + cluster_def=cluster_def, query_topology=self.model_parallelism_enabled)) self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata @@ -272,6 +277,10 @@ class _InternalTPUContext(object): def model_parallelism_enabled(self): return self._model_parallelism_enabled + @property + def input_partition_dims(self): + return self._config.tpu_config.input_partition_dims + @property def device_assignment(self): return (self._get_device_assignment() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 2c7e7d84c0efb69d5f359fddec485ec5846a5326..c104b2403c69529625bf7a6d921a952150b31b3a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -22,6 +22,7 @@ import collections import copy import os import signal +import sys import threading import time @@ -257,7 +258,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metrics=None, export_outputs=None, scaffold_fn=None, - host_call=None): + host_call=None, + training_hooks=None, + evaluation_hooks=None, + prediction_hooks=None): """Creates a validated `TPUEstimatorSpec` instance.""" host_calls = {} if eval_metrics is not None: @@ -265,6 +269,17 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote if host_call is not None: host_calls['host_call'] = host_call _OutfeedHostCall.validate(host_calls) + + training_hooks = list(training_hooks or []) + evaluation_hooks = list(evaluation_hooks or []) + prediction_hooks = list(prediction_hooks or []) + + for hook in training_hooks + evaluation_hooks + prediction_hooks: + if not isinstance(hook, session_run_hook.SessionRunHook): + raise TypeError( + 'All hooks must be SessionRunHook instances, given: {}'.format( + hook)) + return super(TPUEstimatorSpec, cls).__new__( cls, mode=mode, @@ -274,7 +289,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metrics=eval_metrics, export_outputs=export_outputs, scaffold_fn=scaffold_fn, - host_call=host_call) + host_call=host_call, + training_hooks=training_hooks, + evaluation_hooks=evaluation_hooks, + prediction_hooks=prediction_hooks) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" @@ -290,6 +308,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + hooks = list(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -299,9 +318,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metric_ops=eval_metric_ops, export_outputs=self.export_outputs, scaffold=scaffold, - training_hooks=hooks, - evaluation_hooks=hooks, - prediction_hooks=hooks) + training_hooks=self.training_hooks + hooks, + evaluation_hooks=self.evaluation_hooks + hooks, + prediction_hooks=self.prediction_hooks + hooks) class _OpQueueContext(object): @@ -762,16 +781,26 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels)) - control_deps.extend(flattened_inputs) per_host_sharded_inputs.append(flattened_inputs) - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - captured_infeed_queue.capture(infeed_queue) + if inputs_structure_recorder.flattened_input_dims: + # pylint: disable=protected-access + infeed_queue = tpu_feed._PartitionedInfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0]), + host_id=host_id, + input_partition_dims=inputs_structure_recorder.flattened_input_dims, + device_assignment=ctx.device_assignment) + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs) + else: + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0])) + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs, + tpu_ordinal_function=tpu_ordinal_function_impl) + captured_infeed_queue.capture(infeed_queue) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset @@ -790,7 +819,15 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, is_dataset = inputs.is_dataset if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - raise TypeError('Mode PREDICT not yet supported in BROADCAST mode.') + if not is_dataset: + raise TypeError( + 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' + '`features` and `labels`.') + + inputs = _InputsWithStoppingSignals( + dataset=inputs.dataset, + batch_size=ctx.batch_size_for_input_fn, + add_padding=True) if is_dataset: hooks.append(inputs.dataset_initializer_hook()) @@ -809,6 +846,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, """Generates enqueue ops for all the hosts.""" broadcasted_inputs = [] flattened_inputs = None # Cache result from input_fn. + signals = None for host_id in xrange(num_hosts): with ops.device(ctx.tpu_host_placement_function(host_id=host_id)): for _ in xrange(ctx.num_of_replicas_per_host): @@ -818,11 +856,13 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, # hosts). if flattened_inputs is None: features, labels = inputs.features_and_labels() # Calls get_next() + signals = inputs.signals() + inputs_structure_recorder.validate_and_record_structure( - features, labels) + features, labels, signals) flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( - features, labels)) + features, labels, signals)) broadcasted_inputs.append(flattened_inputs) infeed_queue = tpu_feed.InfeedQueue( @@ -832,7 +872,14 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, broadcasted_inputs, tpu_ordinal_function=tpu_ordinal_function_impl, placement_function=device_function_impl) - return enqueue_ops + + if signals is None: + return enqueue_ops + else: + return { + 'ops': enqueue_ops, + 'signals': signals, + } return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset @@ -870,21 +917,68 @@ class _InputPipeline(object): class InputsStructureRecorder(object): """The recorder to record inputs structure.""" - def __init__(self): + def __init__(self, input_partition_dims=None): # Holds the structure of inputs self._feature_names = [] self._label_names = [] self._has_labels = False self._signals_helper = None + self._flattened_input_dims = None + + if input_partition_dims: + # This should have been validated in TPUConfig. + assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' + if len(input_partition_dims) == 2: + self._feature_dims, self._label_dims = input_partition_dims + else: + self._feature_dims = input_partition_dims[0] + self._label_dims = None + + assert self._feature_dims is not None, ('input_partition_dims[0] must ' + 'not be None') + else: + self._feature_dims = None + self._label_dims = None # Internal state. self._initialized = False + @property + def flattened_input_dims(self): + assert self._initialized, 'InputsStructureRecorder is not initialized.' + return self._flattened_input_dims + def has_labels(self): return self._has_labels + def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, + label_dims_names, label_names, has_labels): + """Flatten input dims with the same order as flattened input tensors.""" + flattened_input_dims = [] + if feature_dims_names: + # We need a fixed ordering for matching the tensors in features. + flattened_input_dims.extend( + [feature_dims[name] for name in feature_dims_names]) + else: + flattened_input_dims.append(feature_dims) + + if label_dims_names: + # We need a fixed ordering for matching the tensors in labels. + flattened_input_dims.extend( + [label_dims[name] for name in label_dims_names]) + else: + if label_names: + num_tensors_in_label = len(label_names) + else: + num_tensors_in_label = int(has_labels) + # Setting `None` in input_partition_dims[1] will apply `None` to + # all the tensors in labels, regardless of internal structure. + flattened_input_dims.extend([label_dims] * num_tensors_in_label) + + return flattened_input_dims + def validate_and_record_structure(self, features, labels, signals=None): - """Validates and records the structure of features` and `labels`.""" + """Validates and records the structure of `features` and `labels`.""" def _extract_key_names(tensor_or_dict): if tensor_or_dict is None: @@ -912,6 +1006,24 @@ class _InputPipeline(object): self._feature_names = feature_names self._label_names = label_names self._has_labels = has_labels + if self._feature_dims is not None: + feature_dims_names = _extract_key_names(self._feature_dims) + if feature_dims_names != feature_names: + raise ValueError( + 'TPUConfig.input_partition_dims[0] mismatched feature' + ' keys. Expected {}, got {}'.format(feature_names, + feature_dims_names)) + + label_dims_names = _extract_key_names(self._label_dims) + if self._label_dims is not None and label_dims_names != label_names: + raise ValueError( + 'TPUConfig.input_partition_dims[1] mismatched label' + ' keys. Expected {}, got {}'.format(label_names, + label_dims_names)) + + self._flattened_input_dims = self._flatten_input_dims( + self._feature_dims, feature_dims_names, self._label_dims, + label_dims_names, label_names, has_labels) def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" @@ -1006,7 +1118,8 @@ class _InputPipeline(object): Raises: ValueError: If both `sharded_features` and `num_cores` are `None`. """ - self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder() + self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( + ctx.input_partition_dims) self._sharded_per_core = ctx.is_input_sharded_per_core() self._input_fn = input_fn @@ -1079,9 +1192,11 @@ class _InputPipeline(object): all_hooks.extend(hooks) if is_dataset: run_infeed_loop_on_coordinator = False - enqueue_ops.append( - _wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) + wrap_fn = ( + _wrap_computation_in_while_loop + if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else + _wrap_computation_in_while_loop_with_stopping_signals) + enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) infeed_queues.append(captured_infeed_queue.get()) @@ -1199,6 +1314,7 @@ class _ModelFnWrapper(object): host_call = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_training_hooks = _CapturedObject() def train_step(loss): """Training step function for use inside a while loop.""" @@ -1215,6 +1331,8 @@ class _ModelFnWrapper(object): else: captured_scaffold_fn.capture(None) + captured_training_hooks.capture(estimator_spec.training_hooks) + # We must run train_op to update the variables prior to running the # outfeed. with ops.control_dependencies([train_op]): @@ -1226,7 +1344,8 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_call_outfeed_ops): return array_ops.identity(loss) - return train_step, host_call, captured_scaffold_fn + return (train_step, host_call, captured_scaffold_fn, + captured_training_hooks) def convert_to_single_tpu_eval_step(self, dequeue_fn): """Converts user provided model_fn` as a single eval step on TPU. @@ -1256,6 +1375,7 @@ class _ModelFnWrapper(object): """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_eval_hooks = _CapturedObject() def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" @@ -1270,6 +1390,8 @@ class _ModelFnWrapper(object): loss = tpu_estimator_spec.loss captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) + to_record = {} if tpu_estimator_spec.eval_metrics: to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics @@ -1282,7 +1404,7 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_calls.create_enqueue_op()): return math_ops.add(total_loss, loss) - return eval_step, host_calls, captured_scaffold_fn + return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks def convert_to_single_tpu_predict_step(self, dequeue_fn): """Converts user provided model_fn` as a single predict step on TPU. @@ -1297,6 +1419,7 @@ class _ModelFnWrapper(object): """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_predict_hooks = _CapturedObject() def predict_step(unused_scalar_stopping_signal): """Evaluation step function for use inside a while loop.""" @@ -1317,6 +1440,7 @@ class _ModelFnWrapper(object): self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) to_record = {} identity_fn = lambda **kwargs: kwargs to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] @@ -1328,7 +1452,8 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_calls.create_enqueue_op()): return _StopSignals.as_scalar_stopping_signal(stopping_signals) - return predict_step, host_calls, captured_scaffold_fn + return (predict_step, host_calls, captured_scaffold_fn, + captured_predict_hooks) def _verify_tpu_spec_predictions(self, predictions): """Validates TPUEstimatorSpec.predictions dict.""" @@ -1450,11 +1575,9 @@ class _ModelFnWrapper(object): err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' if estimator_spec.training_chief_hooks: - raise ValueError(err_msg.format('training_chief_hooks')) - if estimator_spec.training_hooks: - raise ValueError(err_msg.format('training_hooks')) - if estimator_spec.evaluation_hooks: - raise ValueError(err_msg.format('evaluation_hooks')) + raise ValueError( + err_msg.format('training_chief_hooks') + 'If you want' + + ' to pass training hooks, please pass via training_hooks.') if estimator_spec.scaffold: logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' @@ -1936,10 +2059,9 @@ class TPUEstimator(estimator_lib.Estimator): """Constructs an `TPUEstimator` instance. Args: - model_fn: Model function as required by `Estimator`. For training, the - returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. Instead, the user can pass the training hooks as - an argument to `TPUEstimator.train()`. + model_fn: Model function as required by `Estimator` which returns + EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', + and `prediction_hooks` must not capure any TPU Tensor inside the model_fn. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -2308,8 +2430,8 @@ class TPUEstimator(estimator_lib.Estimator): input_fn=input_fn, hooks=hooks, steps=steps, max_steps=max_steps, saving_listeners=saving_listeners ) - except Exception as e: # pylint: disable=broad-except - rendezvous.record_error('training_loop', e) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('training_loop', sys.exc_info()) finally: rendezvous.record_done('training_loop') rendezvous.raise_errors() @@ -2323,8 +2445,8 @@ class TPUEstimator(estimator_lib.Estimator): input_fn, steps=steps, hooks=hooks, checkpoint_path=checkpoint_path, name=name ) - except Exception as e: # pylint: disable=broad-except - rendezvous.record_error('evaluation_loop', e) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('evaluation_loop', sys.exc_info()) finally: rendezvous.record_done('evaluation_loop') rendezvous.raise_errors() @@ -2345,8 +2467,8 @@ class TPUEstimator(estimator_lib.Estimator): checkpoint_path=checkpoint_path, yield_single_examples=yield_single_examples): yield result - except Exception as e: # pylint: disable=broad-except - rendezvous.record_error('prediction_loop', e) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('prediction_loop', sys.exc_info()) finally: rendezvous.record_done('prediction_loop') rendezvous.raise_errors() @@ -2408,7 +2530,7 @@ class TPUEstimator(estimator_lib.Estimator): graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold = ( + loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: @@ -2463,6 +2585,9 @@ class TPUEstimator(estimator_lib.Estimator): self._config.tpu_config.iterations_per_loop) hooks.append(examples_hook) + if training_hooks: + hooks.extend(training_hooks) + chief_hooks = [] if (self._config.save_checkpoints_secs or self._config.save_checkpoints_steps): @@ -2474,6 +2599,7 @@ class TPUEstimator(estimator_lib.Estimator): checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) chief_hooks.append(checkpoint_hook) + summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() @@ -2493,7 +2619,7 @@ class TPUEstimator(estimator_lib.Estimator): scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold = _eval_on_tpu_system( + total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div(total_loss, @@ -2537,6 +2663,9 @@ class TPUEstimator(estimator_lib.Estimator): rendezvous=self._rendezvous[mode]), ] + input_hooks + if eval_hooks: + hooks.extend(eval_hooks) + return model_fn_lib.EstimatorSpec( mode, loss=mean_loss, @@ -2547,8 +2676,9 @@ class TPUEstimator(estimator_lib.Estimator): # Predict assert mode == model_fn_lib.ModeKeys.PREDICT - dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) + (dummy_predict_op, host_calls, + scaffold, prediction_hooks) = _predict_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): internal_ops_to_run = _sync_variables_ops() with ops.control_dependencies(internal_ops_to_run): @@ -2604,6 +2734,9 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), ] + input_hooks + if prediction_hooks: + hooks.extend(prediction_hooks) + return model_fn_lib.EstimatorSpec( mode, prediction_hooks=hooks, @@ -2687,8 +2820,8 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_eval_step, host_calls, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) + (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks + ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) def multi_tpu_eval_steps_on_single_shard(): return training_loop.repeat( @@ -2703,15 +2836,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold + return loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_train_step, host_call, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) + (single_tpu_train_step, host_call, captured_scaffold_fn, + captured_training_hooks) = ( + model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) def multi_tpu_train_steps_on_single_shard(): return training_loop.repeat( @@ -2726,15 +2860,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold + return loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" num_cores = ctx.num_cores - single_tpu_predict_step, host_calls, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)) + (single_tpu_predict_step, host_calls, captured_scaffold_fn, + captured_predict_hooks + ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) def multi_tpu_predict_steps_on_single_shard(): @@ -2751,10 +2886,11 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=num_cores, - outputs_from_all_shards=False) + outputs_from_all_shards=False, + device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold + return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() def _wrap_computation_in_while_loop(device, op_fn): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index a44b4f4622afabced9cb1b801acedb0e7b1e5d12..d9c77a3ea1bbc456f058f36d78eec1f0843ddc79 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -20,8 +20,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding +from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding @@ -30,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.util import nest class InfeedQueue(object): @@ -640,3 +646,264 @@ class InfeedQueue(object): tpu_ordinal=tpu_ordinal_function(index)) for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) ] + + +class _PartitionedInfeedQueue(InfeedQueue): + """A helper object to build a device infeed queue with input partition. + + Args: + number_of_tuple_elements: the number of Tensors fed atomically through the + queue, must be present unless it can be inferred from other arguments. + device_assignment: A TPU `DeviceAssignment` which is used to place all the + partitions to different TPU infeed queues. + host_id: The id of the host machine. + input_partition_dims: A nested list/tuple of integers. Each inner + list/tuple describes how to partition the corresponding input tensor. + tuple_types: If not None, a list of types of the elements of the queue. + tuple_shapes: If not None, a list of shapes of the elements of the queue. + name: The name of the queue. + """ + + def __init__(self, + number_of_tuple_elements, + device_assignment, + host_id, + input_partition_dims=None, + tuple_types=None, + tuple_shapes=None, + name=None): + super(_PartitionedInfeedQueue, self).__init__( + number_of_tuple_elements=number_of_tuple_elements, + tuple_types=tuple_types, + tuple_shapes=None, + shard_dimensions=None, + name="PartitionedInfeedQueue" if name is None else name) + self._input_partition_dims = input_partition_dims + self._host_id = host_id + self._device_assignment = device_assignment + + def generate_dequeue_op(self, tpu_device=0): + """Generate TPU dequeue ops. + + Args: + tpu_device: The TPU device ordinal where the infeed instruction should be + placed. + + Returns: + A list of Outputs corresponding to a partition of infeed dequeued + into XLA, suitable for use within a replicated block. + + Raises: + ValueError: if the types or shapes of the tuple elements have not been + set; or if a dequeue op has already been generated. + """ + self.freeze() + if self._generated_dequeue_op: + raise ValueError("Can't generate two dequeue Ops from the same queue") + self._generated_dequeue_op = True + full_name = "%s/dequeue" % self._name + sharded_shapes = [ + policy.get_sharded_shape(shape) + for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) + ] + with ops.device(tpu.core(tpu_device)): + values = tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + return self._tag_sharding_attribute_for_dequeued_tensors( + values, self._input_partition_dims) + + def generate_enqueue_ops(self, per_host_sharded_inputs): + """Generates the host-side Ops to enqueue the partitioned inputs. + + per_host_sharded_inputs is a list, one for each replica, of lists of + Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed + replica i. + sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. + + For example, if sharded_inputs[i][j] is a 2-D Tensor: + [[A, B, C, D], + [E ,F, G, H]] + self._input_partition_dims[j] is [2, 4]. + + sharded_inputs[i][j] will be partitioned and flattened into: + [A, B, C, D, E, F, G, H] and fed into the logical core ids: + [0, 1, 2, 3, 4, 5, 6, 7] respectively. + + Args: + per_host_sharded_inputs: a list of lists of Tensors. The length of the + outer list determines the number of shards. Each inner list indicates + the types and shapes of the tuples in the corresponding shard. + + Returns: + A list of host-side Ops, one for each shard, that when executed together + will enqueue a full-size element of infeed. + + Raises: + ValueError: if the queue configuration has previously been frozen and the + shapes of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the shapes of the elements of sharded_inputs + don't form a consistent unsharded tuple; or if the elements of a tuple + have different device constraints; or if the partition dims are invalid. + TypeError: if the queue configuration has previously been frozen and the + types of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the types of the elements of sharded_inputs + don't form a consistent unsharded tuple. + """ + self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs) + number_of_replicas_per_host = len(per_host_sharded_inputs) + number_of_tuple_elements = len(per_host_sharded_inputs[0]) + + assert len(self._input_partition_dims) == number_of_tuple_elements + per_host_enqueue_ops = [] + + for replica_index in range(number_of_replicas_per_host): + flattened_inputs = per_host_sharded_inputs[replica_index] + inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, + self._input_partition_dims) + inputs_parted_iters = [ + iter(self._partition_or_replicate_on_host(x, dims)) for x, dims in + zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat) + ] + + for core_index in xrange(self._device_assignment.num_cores_per_replica): + # Places different partitions to different logic cores. + logical_core = self._get_logical_core(core_index) + replica_id = self._device_assignment.lookup_replicas( + self._host_id, logical_core)[replica_index] + ordinal = self._device_assignment.tpu_ordinal( + replica=replica_id, logical_core=logical_core) + infeed_inputs = [] + for it in inputs_parted_iters: + input_for_device = next(it, None) + if input_for_device is not None: + infeed_inputs.append(input_for_device) + + if infeed_inputs: + per_host_enqueue_ops.append( + tpu_ops.infeed_enqueue_tuple( + inputs=infeed_inputs, + shapes=[x.shape for x in infeed_inputs], + name="enqueue/replica_{0}/input_{1}".format( + replica_index, core_index), + device_ordinal=ordinal)) + return per_host_enqueue_ops + + def _check_input_partition_dims(self, tensor, dims): + """Checks that input partition dims are valid for the `Tensor`. + + Args: + tensor: Input tensor for partitioning. + dims: A list of integer describes how to partition the input tensor. + + Raises: + ValueError: If the tensor can't be partitioned by dims or the + num_cores_per_replica doesn't match the number of + partitions(dims.prod()). + """ + if dims is None: + return + + dims = np.array(dims) + + if (dims < 1).any(): + raise ValueError("All input partition dims must be >= 1.") + + # No partitioning, so don't perform further checks. + if dims.prod() == 1: + return + + if dims.prod() != self._device_assignment.num_cores_per_replica: + raise ValueError( + "The product of each input parition dim should equal to " + "num_cores_per_replica. (dim = {}, num_cores_per_replica " + "= {})".format(dims, self._device_assignment.num_cores_per_replica)) + if dims.shape[0] != tensor.shape.ndims: + raise ValueError( + "Input partition dims must have the same number of dimensions " + "as the `Tensor` to be partitioned. (tensor shape = {}, input " + "partition dims = {}).".format(tensor.shape.as_list(), dims)) + + tensor.shape.assert_is_fully_defined() + if (np.array(tensor.shape.as_list()) % dims != 0).any(): + raise ValueError( + "All input partition dims must divide exactly into the `Tensor` " + "shape (tensor shape = {}, input partition dims = {}).".format( + tensor.shape.as_list(), dims)) + + def _partition_or_replicate_on_host(self, tensor, dims): + """Partitions or replicates the input tensor. + + The ops inside this function are placed on the host side. + + Args: + tensor: The input tensor which will be partioned or replicated. + dims: A list of integer describes how to partition the input tensor. + Returns: + An iterator of `Tensor`s or a list of partioned tensors. + """ + self._check_input_partition_dims(tensor, dims) + if dims is None: + return itertools.repeat(tensor) + else: + output = [tensor] + for axis, dim in enumerate(dims): + if dim > 1: + output = [array_ops.split(x, dim, axis=axis) for x in output] + output = nest.flatten(output) + return output + + def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensor. + + Args: + tensor: The dequeued tensor on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same tensor with the xla_sharding attribute. + """ + if dims is None: + return xla_sharding.replicate(tensor) + elif np.prod(dims) == 1: + return xla_sharding.assign_device(tensor, 0) + else: + tile_shape = np.array(tensor.shape.as_list()) // dims + tile_assignment = np.arange(np.prod(dims)).reshape(dims) + return xla_sharding.tile( + tensor=tensor, + tile_shape=xla_shape.CreateShapeFromDtypeAndTuple( + dtype=np.dtype(tensor.dtype.as_numpy_dtype), + shape_tuple=tile_shape), + tile_assignment=tile_assignment) + + def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensors. + + Args: + dequeues: A list of dequeued tensors on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same dequeues with appropriate xla_sharding attribute. + """ + nest.assert_shallow_structure(dequeues, dims) + return nest.map_structure_up_to( + dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues, + dims) + + def _get_logical_core(self, core_index): + """Maps the core index to the 3D coordinate within replica. + + The lowest dimension number in computation_shape is the slowest varying + dimension (most major). + + Args: + core_index: An integer represents the core index within replcia. + + Returns: + A tuple with three integers which represents the 3D coordinate. + """ + computation_shape = self._device_assignment.computation_shape + return (core_index // (computation_shape[1] * computation_shape[2]), + core_index % (computation_shape[1] * computation_shape[2]) // + computation_shape[2], core_index % computation_shape[2]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index 894f21d0635ca47d3da1c0d2c3f5c37bac690920..ec682e5829c4df536a043334b74200f0b6259df3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -45,7 +45,7 @@ _TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ ]) -def _query_tpu_system_metadata(master_address, run_config, +def _query_tpu_system_metadata(master_address, cluster_def=None, query_topology=False): """Automatically detects the TPU system metadata in the system.""" tpu_core_count = 0 @@ -61,7 +61,8 @@ def _query_tpu_system_metadata(master_address, run_config, with session_lib.Session( master_address, config=get_session_config_with_timeout( - _PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess: + _PINGING_MASTER_TIMEOUT_IN_MS, + cluster_def)) as sess: devices = sess.list_devices() for device in devices: match = _TPU_DEVICE_REG.match(device.name) @@ -105,7 +106,7 @@ def _query_tpu_system_metadata(master_address, run_config, 'TPU worker has some problems. Available devices: {}'.format( master_address, devices)) - topology = _obtain_topology(master_address, run_config) + topology = _obtain_topology(master_address, cluster_def) metadata = _TPUSystemMetadata( num_cores=tpu_core_count, @@ -127,14 +128,15 @@ def _query_tpu_system_metadata(master_address, run_config, return metadata -def _obtain_topology(master_address, run_config): +def _obtain_topology(master_address, cluster_def): + """Obtains TPU fabric topology.""" try: logging.info('Initializing TPU system (master: %s) to fetch topology ' 'for model parallelism. This might take a while.', master_address) with ops.Graph().as_default(): session_config = get_session_config_with_timeout( - _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config) + _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def) with session_lib.Session( master_address, config=session_config) as sess: topology = sess.run(tpu.initialize_system()) @@ -146,11 +148,8 @@ def _obtain_topology(master_address, run_config): master_address)) -def get_session_config_with_timeout(timeout_in_secs, run_config): - cluster_def = None - if run_config.session_config and run_config.session_config.cluster_def.job: - cluster_def = run_config.session_config.cluster_def - +def get_session_config_with_timeout(timeout_in_secs, cluster_def): + """Returns a session given a timeout and a cluster configuration.""" config = config_pb2.ConfigProto( operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) return config diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index f7fd66d33fc0c329db7daaf87373385156d84217..01bac891da7ddf8523e6cc8c99decf4a61aa2741 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import evaluation from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir, logging.info('Waiting for new checkpoint at %s', checkpoint_dir) stop_time = time.time() + timeout if timeout is not None else None while True: - checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir) if checkpoint_path is None or checkpoint_path == last_checkpoint: if stop_time is not None and time.time() + seconds_to_sleep > stop_time: return None diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py index 4877c010fad2c567d26b9674d2904274c0895f55..94cf7788b2bd3bc3fe87eefd599ce88de03042af 100644 --- a/tensorflow/contrib/training/python/training/training_test.py +++ b/tensorflow/contrib/training/python/training/training_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -421,7 +422,7 @@ class TrainTest(test.TestCase): train_op = self.create_train_op() model_variables = variables_lib2.global_variables() - model_path = saver_lib.latest_checkpoint(logdir1) + model_path = checkpoint_management.latest_checkpoint(logdir1) assign_fn = variables_lib.assign_from_checkpoint_fn( model_path, model_variables) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 84555b60dac137478fe48a96d3afebf9a25d67b1..1423c7fbcb227c3a00d74fb62c8e2b547d93c41c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2238,6 +2238,7 @@ cc_library( linkopts = ["-ldl"], deps = [ "//tensorflow/core/platform/default/build_config:jpeg", + "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2266,6 +2267,7 @@ cc_library( linkopts = ["-ldl"], deps = [ "//tensorflow/core/platform/default/build_config:gif", + "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2292,6 +2294,7 @@ cc_library( copts = tf_copts(), linkopts = ["-ldl"], deps = [ + "//tensorflow/core/platform/default/build_config:logging", "@png_archive//:png", ], ) @@ -2924,6 +2927,14 @@ tf_cuda_library( ] + tf_additional_device_tracer_deps(), ) +cc_library( + name = "session_ref", + srcs = ["common_runtime/session_ref.cc"], + hdrs = ["common_runtime/session_ref.h"], + copts = tf_copts(), + deps = [":core_cpu_base"], +) + cc_library( name = "gpu_id", hdrs = [ @@ -3225,6 +3236,7 @@ tf_cc_tests( "platform/fingerprint_test.cc", "platform/integral_types_test.cc", "platform/logging_test.cc", + "platform/mutex_test.cc", "platform/net_test.cc", "platform/port_test.cc", "platform/profile_utils/cpu_utils_test.cc", @@ -3482,6 +3494,7 @@ tf_cc_tests( "framework/tensor_shape_test.cc", "framework/tensor_slice_test.cc", "framework/tensor_test.cc", + "framework/tensor_testutil_test.cc", "framework/tensor_util_test.cc", "framework/tracking_allocator_test.cc", "framework/types_test.cc", diff --git a/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt b/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt index ad1ada8d717a51ee3a058da5d32ed7bf50375b13..3134fceecabb4969f5d8cf3a67e9288c7ca2a186 100644 --- a/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt @@ -1,4 +1,4 @@ op { graph_op_name: "Ceil" - summary: "Returns element-wise smallest integer in not less than x." + summary: "Returns element-wise smallest integer not less than x." } diff --git a/tensorflow/core/api_def/base_api/api_def_FilterByLastComponentDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_FilterByLastComponentDataset.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0b41229872347c586dd644f557df2f0dbdcddf5e --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_FilterByLastComponentDataset.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "FilterByLastComponentDataset" + visibility: HIDDEN + summary: + "Creates a dataset containing elements of first " + "component of `input_dataset` having true in the last component." +} diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt index ea5669693e09c576d6cf9039846903a317c3b128..dfd199d0128be0225b348f76ba10e0e1dc951b61 100644 --- a/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt @@ -1,4 +1,4 @@ op { graph_op_name: "IteratorGetNext" - summary: "Gets the next output from the given iterator." + summary: "Gets the next output from the given iterator ." } diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorGetNextAsOptional.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorGetNextAsOptional.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7068336847eacb5521d0e413b8158fe96c67bfaa --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_IteratorGetNextAsOptional.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "IteratorGetNextAsOptional" + summary: "Gets the next output from the given iterator as an Optional variant." +} diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..44336937598dda2816de2c94bfafae3532f63441 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt @@ -0,0 +1,34 @@ +op { + graph_op_name: "MapDefun" + visibility: HIDDEN + in_arg { + name: "arguments" + description: <